From 8edfc1a2b7e27f6bb25fd4a09e5c3770be5fab04 Mon Sep 17 00:00:00 2001 From: Andrew Steinheiser Date: Wed, 25 Jun 2025 14:21:28 -0700 Subject: [PATCH 1/6] Add schema for new responses API (#785) * update contributing * openai v5.6.0 * create schema for responses API * add skeleton route to conversations router * add some basic tests * put package back * put package lock back * add case for message array input * feat: complete create response test suite * add minimum to max_output_tokens * add more test cases * add test case * set zod version to 3.25 for v4 support * start of pr feedback * adjust schema and tests to not allow empty message strings/arrays * move create responses * create new (very basic) router for responses API * basic test for responses router, will be expanded once middleware is added * update test * update func name * create route for responses API -- ensure certain options are set via config * clean up tests by moving MONGO_CHAT_MODEL string to test config * use GenerateResponse type * update comments in router * update reqId stuff * update comment --- CONTRIBUTING.md | 6 +- package-lock.json | 6 +- .../src/config.ts | 13 + .../src/test/testHelpers.ts | 1 - packages/mongodb-chatbot-server/package.json | 1 + packages/mongodb-chatbot-server/src/app.ts | 11 +- .../src/routes/index.ts | 1 + .../routes/responses/createResponse.test.ts | 551 ++++++++++++++++++ .../src/routes/responses/createResponse.ts | 195 +++++++ .../src/routes/responses/index.ts | 1 + .../routes/responses/responsesRouter.test.ts | 50 ++ .../src/routes/responses/responsesRouter.ts | 33 ++ .../src/test/testConfig.ts | 15 + 13 files changed, 877 insertions(+), 7 deletions(-) create mode 100644 packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts create mode 100644 packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts create mode 100644 packages/mongodb-chatbot-server/src/routes/responses/index.ts create mode 100644 packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts create mode 100644 packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1c6a81e85..0628a8516 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,7 +23,7 @@ The monorepo has the following main projects, each of which correspond to a Java These packages power our RAG applications. - `mongodb-rag-core`: A set of common resources (modules, functions, types, etc.) shared across projects. - - You need to recompile `mongodb-rag-core` by running `npm run build` every time you update it for the changes to be accessible in the other projects that dependend on it. + - You need to recompile `mongodb-rag-core` by running `npm run build` every time you update it for the changes to be accessible in the other projects that depend on it. - `mongodb-rag-ingest`: CLI application that takes data from data sources and converts it to `embedded_content` used by Atlas Vector Search. ### MongoDB Chatbot Framework @@ -40,7 +40,7 @@ general, we publish these as reusable packages on npm. These packages are our production chatbot. They build on top of the Chatbot Framework packages and add MongoDB-specific implementations. -- `chatbot-eval-mongodb-public`: Test suites, evaluators, and reports for the MongoDB AI Chatbot +- `chatbot-eval-mongodb-public`: Test suites, evaluators, and reports for the MongoDB AI Chatbot. - `chatbot-server-mongodb-public`: Chatbot server implementation with our MongoDB-specific configuration. 
- `ingest-mongodb-public`: RAG ingest service configured to ingest MongoDB Docs, DevCenter, MDBU, MongoDB Press, etc. @@ -132,7 +132,7 @@ npm run dev ## Infrastructure -The projects uses Drone for its CI/CD pipeline. All drone config is located in `.drone.yml`. +The projects use Drone for their CI/CD pipeline. All drone configs are located in `.drone.yml`. Applications are deployed on Kubernetes using the Kanopy developer platform. Kubernetes/Kanopy configuration are found in the `/environments` diff --git a/package-lock.json b/package-lock.json index 60da7d2a0..7dfd59e2b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -53333,8 +53333,9 @@ } }, "node_modules/zod": { - "version": "3.25.48", - "license": "MIT", + "version": "3.25.67", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.67.tgz", + "integrity": "sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==", "funding": { "url": "https://github.com/sponsors/colinhacks" } @@ -54596,6 +54597,7 @@ "rate-limit-mongo": "^2.3.2", "stream-json": "^1.8.0", "winston": "^3.9.0", + "zod": "^3.25.67", "zod-error": "^1.5.0" }, "devDependencies": { diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index 5bb7fa705..c2192b2b6 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -386,6 +386,19 @@ export const config: AppConfig = { maxInputLengthCharacters: 3000, braintrustLogger, }, + responsesRouterConfig: { + createResponse: { + supportedModels: ["mongodb-chat-latest"], + maxOutputTokens: 4000, + generateResponse: () => + Promise.resolve({ + messages: [ + { role: "user", content: "What is MongoDB?" }, + { role: "assistant", content: "MongoDB is a database." 
}, + ], + }), + }, + }, maxRequestTimeoutMs: 60000, corsOptions: { origin: allowedOrigins, diff --git a/packages/chatbot-server-mongodb-public/src/test/testHelpers.ts b/packages/chatbot-server-mongodb-public/src/test/testHelpers.ts index eab861ff3..dcbacb139 100644 --- a/packages/chatbot-server-mongodb-public/src/test/testHelpers.ts +++ b/packages/chatbot-server-mongodb-public/src/test/testHelpers.ts @@ -64,7 +64,6 @@ export async function makeTestApp(defaultConfigOverrides?: Partial) { export { systemPrompt }; export { - generateUserPrompt, openAiClient, OPENAI_CHAT_COMPLETION_DEPLOYMENT, OPENAI_RETRIEVAL_EMBEDDING_DEPLOYMENT, diff --git a/packages/mongodb-chatbot-server/package.json b/packages/mongodb-chatbot-server/package.json index a7e10a4d4..dc33559ea 100644 --- a/packages/mongodb-chatbot-server/package.json +++ b/packages/mongodb-chatbot-server/package.json @@ -55,6 +55,7 @@ "rate-limit-mongo": "^2.3.2", "stream-json": "^1.8.0", "winston": "^3.9.0", + "zod": "^3.25.67", "zod-error": "^1.5.0" }, "devDependencies": { diff --git a/packages/mongodb-chatbot-server/src/app.ts b/packages/mongodb-chatbot-server/src/app.ts index c2ae01a5e..41626dd5c 100644 --- a/packages/mongodb-chatbot-server/src/app.ts +++ b/packages/mongodb-chatbot-server/src/app.ts @@ -11,7 +11,9 @@ import "dotenv/config"; import { ConversationsRouterParams, makeConversationsRouter, -} from "./routes/conversations/conversationsRouter"; + ResponsesRouterParams, + makeResponsesRouter, +} from "./routes"; import { logger } from "mongodb-rag-core"; import { ObjectId } from "mongodb-rag-core/mongodb"; import { getRequestId, logRequest, sendErrorResponse } from "./utils"; @@ -27,6 +29,11 @@ export interface AppConfig { */ conversationsRouterConfig: ConversationsRouterParams; + /** + Configuration for the responses router. + */ + responsesRouterConfig: ResponsesRouterParams; + /** Maximum time in milliseconds for a request to complete before timing out. Defaults to 60000 (1 minute). 
@@ -116,6 +123,7 @@ export const makeApp = async (config: AppConfig): Promise => { const { maxRequestTimeoutMs = DEFAULT_MAX_REQUEST_TIMEOUT_MS, conversationsRouterConfig, + responsesRouterConfig, corsOptions, apiPrefix = DEFAULT_API_PREFIX, expressAppConfig, @@ -140,6 +148,7 @@ export const makeApp = async (config: AppConfig): Promise => { `${apiPrefix}/conversations`, makeConversationsRouter(conversationsRouterConfig) ); + app.use(`${apiPrefix}/responses`, makeResponsesRouter(responsesRouterConfig)); app.get("/health", (_req, res) => { const data = { diff --git a/packages/mongodb-chatbot-server/src/routes/index.ts b/packages/mongodb-chatbot-server/src/routes/index.ts index b9f9da7be..d3e816609 100644 --- a/packages/mongodb-chatbot-server/src/routes/index.ts +++ b/packages/mongodb-chatbot-server/src/routes/index.ts @@ -1 +1,2 @@ export * from "./conversations"; +export * from "./responses"; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts new file mode 100644 index 000000000..f8d527557 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts @@ -0,0 +1,551 @@ +import "dotenv/config"; +import request from "supertest"; +import { Express } from "express"; +import { DEFAULT_API_PREFIX } from "../../app"; +import { makeTestApp } from "../../test/testHelpers"; +import { MONGO_CHAT_MODEL } from "../../test/testConfig"; + +jest.setTimeout(100000); + +describe("POST /responses", () => { + const endpointUrl = `${DEFAULT_API_PREFIX}/responses`; + let app: Express; + let ipAddress: string; + let origin: string; + + beforeEach(async () => { + ({ app, ipAddress, origin } = await makeTestApp()); + }); + + describe("Valid requests", () => { + it("Should return 200 given a string input", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 given a message array input", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "What is MongoDB?" }, + { role: "assistant", content: "MongoDB is a document database." }, + { role: "user", content: "What is a document database?" 
}, + ], + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 given a valid request with instructions", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + instructions: "You are a helpful chatbot.", + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with valid max_output_tokens", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + max_output_tokens: 4000, + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with valid metadata", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + metadata: { key1: "value1", key2: "value2" }, + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with valid temperature", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + temperature: 0, + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with previous_response_id", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + previous_response_id: "some-id", + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with user", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + user: "some-user-id", + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with store=false", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + store: false, + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with store=true", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + store: true, + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with tools and tool_choice", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + tools: [ + { + name: "test-tool", + description: "A tool for testing.", + parameters: { + type: "object", + properties: { + query: { type: "string" }, + }, + required: ["query"], + }, + }, + ], + tool_choice: "auto", + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with a specific function tool_choice", async () => { + const response = await request(app) + 
.post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + tools: [ + { + name: "test-tool", + description: "A tool for testing.", + parameters: { + type: "object", + properties: { + query: { type: "string" }, + }, + required: ["query"], + }, + }, + ], + tool_choice: { + type: "function", + name: "test-tool", + }, + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 given a message array with function_call", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: [ + { role: "user", content: "What is MongoDB?" }, + { + type: "function_call", + id: "call123", + name: "my_function", + arguments: `{"query": "value"}`, + status: "in_progress", + }, + ], + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 given a message array with function_call_output", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: [ + { role: "user", content: "What is MongoDB?" }, + { + type: "function_call_output", + call_id: "call123", + output: `{"result": "success"}`, + status: "completed", + }, + ], + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with tool_choice 'none'", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + tool_choice: "none", + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with tool_choice 'only'", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + tool_choice: "only", + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with an empty tools array", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + tools: [], + }); + + expect(response.statusCode).toBe(200); + }); + }); + + // TODO: In EAI-1126, we will need to change the error types to match the OpenAI spec + describe("Invalid requests", () => { + it("Should return 400 with an empty input string", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "", + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 with an empty message array", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: [], + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if model is not mongodb-chat-latest", async () => { + const response = await request(app) + .post(endpointUrl) + 
.set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: "gpt-4o-mini", + stream: true, + input: "What is MongoDB?", + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ + error: "Model gpt-4o-mini is not supported.", + }); + }); + + it("Should return 400 if stream is not true", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: false, + input: "What is MongoDB?", + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if max_output_tokens is > 4000", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + max_output_tokens: 4001, + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ + error: + "Max output tokens 4001 is greater than the maximum allowed 4000.", + }); + }); + + it("Should return 400 if metadata has too many fields", async () => { + const metadata: Record = {}; + for (let i = 0; i < 17; i++) { + metadata[`key${i}`] = "value"; + } + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + metadata, + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if metadata value is too long", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + metadata: { key1: "a".repeat(513) }, + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if temperature is not 0", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + temperature: 0.5, + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if messages contain an invalid role", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: [ + { role: "user", content: "What is MongoDB?" }, + { role: "invalid-role", content: "This is an invalid role." 
}, + ], + }); + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if function_call has an invalid status", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: [ + { + type: "function_call", + id: "call123", + name: "my_function", + arguments: `{"query": "value"}`, + status: "invalid_status", + }, + ], + }); + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if function_call_output has an invalid status", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: [ + { + type: "function_call_output", + call_id: "call123", + output: `{"result": "success"}`, + status: "invalid_status", + }, + ], + }); + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 with an invalid tool_choice string", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + tool_choice: "invalid_choice", + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + + it("Should return 400 if max_output_tokens is negative", async () => { + const response = await request(app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + max_output_tokens: -1, + }); + + expect(response.statusCode).toBe(400); + expect(response.body).toEqual({ error: "Invalid request" }); + }); + }); +}); diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts new file mode 100644 index 000000000..daa41febc --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts @@ -0,0 +1,195 @@ +import { z } from "zod"; +import { + Request as ExpressRequest, + Response as ExpressResponse, +} from "express"; +import { RequestError, makeRequestError } from "../conversations/utils"; +import { SomeExpressRequest } from "../../middleware"; +import { getRequestId, sendErrorResponse } from "../../utils"; +import { GenerateResponse } from "../../processors"; + +export const CreateResponseRequestBodySchema = z.object({ + model: z.string(), + instructions: z.string().optional(), + input: z.union([ + z + .string() + .refine((input) => input.length > 0, "Input must be a non-empty string"), + z + .array( + z.union([ + z.object({ + type: z.literal("message").optional(), + role: z.enum(["user", "assistant", "system"]), + content: z.string(), + }), + // function tool call + z.object({ + type: z.literal("function_call"), + id: z + .string() + .optional() + .describe("Unique ID of the function tool call"), + name: z.string().describe("Name of the function tool to call"), + arguments: z + .string() + .describe( + "JSON string of arguments passed to the function tool call" + ), + status: z.enum(["in_progress", "completed", "incomplete"]), + }), + // function tool call output + z.object({ + type: 
z.literal("function_call_output"), + id: z + .string() + .optional() + .describe("The unique ID of the function tool call output"), + call_id: z + .string() + .describe( + "Unique ID of the function tool call generated by the model" + ), + output: z + .string() + .describe("JSON string of the function tool call"), + status: z.enum(["in_progress", "completed", "incomplete"]), + }), + ]) + ) + .refine((input) => input.length > 0, "Input must be a non-empty array"), + ]), + max_output_tokens: z.number().min(0).default(1000), + metadata: z + .record(z.string(), z.string().max(512)) + .optional() + .refine( + (metadata) => Object.keys(metadata ?? {}).length <= 16, + "Too many metadata fields. Max 16." + ), + previous_response_id: z + .string() + .optional() + .describe("The unique ID of the previous response to the model."), + store: z + .boolean() + .optional() + .describe("Whether to store the response in the conversation.") + .default(true), + stream: z.literal(true, { + errorMap: () => ({ message: "'stream' must be true" }), + }), + temperature: z + .union([ + z.literal(0, { + errorMap: () => ({ message: "Temperature must be 0 or unset" }), + }), + z.undefined(), + ]) + .optional() + .describe("Temperature for the model. Defaults to 0.") + .default(0), + tool_choice: z + .union([ + z.enum(["none", "only", "auto"]), + z + .object({ + name: z.string(), + type: z.literal("function"), + }) + .describe("Function tool choice"), + ]) + .optional() + .describe("Tool choice for the model. Defaults to 'auto'.") + .default("auto"), + tools: z + .array( + z.object({ + name: z.string(), + description: z.string().optional(), + parameters: z + .record(z.string(), z.unknown()) + .describe( + "A JSON schema object describing the parameters of the function." + ), + }) + ) + .optional() + .describe("Tools for the model to use."), + + user: z.string().optional().describe("The user ID of the user."), +}); + +export const CreateResponseRequest = SomeExpressRequest.merge( + z.object({ + headers: z.object({ + "req-id": z.string(), + }), + body: CreateResponseRequestBodySchema, + }) +); + +export type CreateResponseRequest = z.infer; + +export interface CreateResponseRouteParams { + generateResponse: GenerateResponse; + supportedModels: string[]; + maxOutputTokens: number; +} + +export function makeCreateResponseRoute({ + // generateResponse, + supportedModels, + maxOutputTokens, +}: CreateResponseRouteParams) { + return async ( + req: ExpressRequest, + res: ExpressResponse<{ status: string }, any> + ) => { + const reqId = getRequestId(req); + try { + const { + body: { model, max_output_tokens }, + } = req; + + // --- MODEL CHECK --- + if (!supportedModels.includes(model)) { + throw makeRequestError({ + httpStatus: 400, + message: `Model ${model} is not supported.`, + }); + } + + // --- MAX OUTPUT TOKENS CHECK --- + if (max_output_tokens > maxOutputTokens) { + throw makeRequestError({ + httpStatus: 400, + message: `Max output tokens ${max_output_tokens} is greater than the maximum allowed ${maxOutputTokens}.`, + }); + } + + // TODO: actually use this call + // generateResponse(); + // TODO: do something with maxOutputTokens (validate result length or pass to generateResponse?) + + return res.status(200).send({ status: "ok" }); + } catch (error) { + // TODO: better error handling, in line with the Responses API + const { httpStatus, message } = + (error as Error).name === "RequestError" + ? 
(error as RequestError) + : makeRequestError({ + message: (error as Error).message, + stack: (error as Error).stack, + httpStatus: 500, + }); + + sendErrorResponse({ + res, + reqId, + httpStatus, + errorMessage: message, + }); + } + }; +} diff --git a/packages/mongodb-chatbot-server/src/routes/responses/index.ts b/packages/mongodb-chatbot-server/src/routes/responses/index.ts new file mode 100644 index 000000000..a0523d4ea --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/responses/index.ts @@ -0,0 +1 @@ +export * from "./responsesRouter"; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts new file mode 100644 index 000000000..8f6e156c0 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts @@ -0,0 +1,50 @@ +import request from "supertest"; +import { AppConfig } from "../../app"; +import { DEFAULT_API_PREFIX } from "../../app"; +import { makeTestApp } from "../../test/testHelpers"; +import { makeTestAppConfig } from "../../test/testHelpers"; +import { MONGO_CHAT_MODEL } from "../../test/testConfig"; + +jest.setTimeout(60000); + +describe("Responses Router", () => { + const ipAddress = "127.0.0.1"; + const responsesEndpoint = DEFAULT_API_PREFIX + "/responses"; + const validRequestBody = { + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", + }; + let appConfig: AppConfig; + + beforeAll(async () => { + ({ appConfig } = await makeTestAppConfig()); + }); + + it("should return 200 given a valid request", async () => { + const { app, origin } = await makeTestApp({ + ...appConfig, + responsesRouterConfig: { + createResponse: { + supportedModels: [MONGO_CHAT_MODEL], + maxOutputTokens: 4000, + generateResponse: () => + Promise.resolve({ + messages: [ + { role: "user", content: "What is MongoDB?" }, + { role: "assistant", content: "MongoDB is a database." }, + ], + }), + }, + }, + }); + + const res = await request(app) + .post(responsesEndpoint) + .set("X-FORWARDED-FOR", ipAddress) + .set("Origin", origin) + .send(validRequestBody); + + expect(res.status).toBe(200); + }); +}); diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts new file mode 100644 index 000000000..9079379b6 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts @@ -0,0 +1,33 @@ +import Router from "express-promise-router"; +import validateRequestSchema from "../../middleware/validateRequestSchema"; +import { + makeCreateResponseRoute, + CreateResponseRequest, +} from "./createResponse"; +import { GenerateResponse } from "../../processors"; + +export interface ResponsesRouterParams { + createResponse: { + generateResponse: GenerateResponse; + supportedModels: string[]; + maxOutputTokens: number; + }; +} + +/** + Constructor function to make the /responses/* Express.js router. 
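+ Registers a single route: POST / validates the request body against
+ CreateResponseRequest before delegating to makeCreateResponseRoute with the
+ configured generateResponse, supportedModels, and maxOutputTokens. Rate
+ limiting is not configured yet (see the TODO below).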
+ */ +export function makeResponsesRouter({ createResponse }: ResponsesRouterParams) { + const responsesRouter = Router(); + + // TODO: add rate limit config + + // Create Response API + responsesRouter.post( + "/", + validateRequestSchema(CreateResponseRequest), + makeCreateResponseRoute(createResponse) + ); + + return responsesRouter; +} diff --git a/packages/mongodb-chatbot-server/src/test/testConfig.ts b/packages/mongodb-chatbot-server/src/test/testConfig.ts index 100757731..3f7337041 100644 --- a/packages/mongodb-chatbot-server/src/test/testConfig.ts +++ b/packages/mongodb-chatbot-server/src/test/testConfig.ts @@ -171,6 +171,8 @@ export const mockGenerateResponse: GenerateResponse = async ({ }; }; +export const MONGO_CHAT_MODEL = "mongodb-chat-latest"; + export async function makeDefaultConfig(): Promise { const conversations = makeMongoDbConversationsService(memoryDb); return { @@ -178,6 +180,19 @@ export async function makeDefaultConfig(): Promise { generateResponse: mockGenerateResponse, conversations, }, + responsesRouterConfig: { + createResponse: { + supportedModels: [MONGO_CHAT_MODEL], + maxOutputTokens: 4000, + generateResponse: () => + Promise.resolve({ + messages: [ + { role: "user", content: "What is MongoDB?" }, + { role: "assistant", content: "MongoDB is a database." }, + ], + }), + }, + }, maxRequestTimeoutMs: 30000, corsOptions: { origin: allowedOrigins, From 47ee7bf766d9f9117f70fd667b45c724c96550d7 Mon Sep 17 00:00:00 2001 From: Andrew Steinheiser Date: Fri, 27 Jun 2025 14:59:44 -0700 Subject: [PATCH 2/6] Improve responses api errors (#789) * create errors helper * update errors to include http code * add sendErrorResponse helper * add enum for error codes * update createResponse to use new error helpers * handle input validation in createResponse vs middleware * mostly update tests * improve error messages from zod * update tests * adjust variable names/exports * add test case for unknown errors * remove openai dep from mongodb-chat-server in favor of importing from rag-core * update errors to use openai types and classes * update tests with new validation errors * improve router test * Add rate limiting to responses api (#792) * add rate limit and global slowdown to responses router * basic working rate limit with test for responses router * add more test assertions for openai error * configure rateLimit middleware properly with makeRateLimitError helper * update test case * update test case * use sendErrorResponse helper within rateLimit middleware to ensure we get logging there too * remove extra comment * update error message * abstract error strings for tests --- package-lock.json | 382 ++---------------- packages/mongodb-chatbot-server/package.json | 1 - .../routes/responses/createResponse.test.ts | 82 +++- .../src/routes/responses/createResponse.ts | 98 +++-- .../src/routes/responses/errors.ts | 125 ++++++ .../routes/responses/responsesRouter.test.ts | 108 +++++ .../src/routes/responses/responsesRouter.ts | 51 ++- .../src/test/testHelpers.ts | 7 +- 8 files changed, 432 insertions(+), 422 deletions(-) create mode 100644 packages/mongodb-chatbot-server/src/routes/responses/errors.ts diff --git a/package-lock.json b/package-lock.json index 7dfd59e2b..6b89981de 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8813,13 +8813,6 @@ "uuid": "dist/bin/uuid" } }, - "node_modules/@langchain/openai/node_modules/@types/node": { - "version": "18.19.45", - "license": "MIT", - "dependencies": { - "undici-types": "~5.26.4" - } - }, 
"node_modules/@langchain/openai/node_modules/ansi-styles": { "version": "5.2.0", "license": "MIT", @@ -8840,10 +8833,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/@langchain/openai/node_modules/form-data-encoder": { - "version": "1.7.2", - "license": "MIT" - }, "node_modules/@langchain/openai/node_modules/langsmith": { "version": "0.1.42", "license": "MIT", @@ -8872,34 +8861,6 @@ } } }, - "node_modules/@langchain/openai/node_modules/openai": { - "version": "4.95.0", - "license": "Apache-2.0", - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, - "bin": { - "openai": "bin/cli" - }, - "peerDependencies": { - "ws": "^8.18.0", - "zod": "^3.23.8" - }, - "peerDependenciesMeta": { - "ws": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, "node_modules/@langchain/openai/node_modules/semver": { "version": "7.6.3", "license": "ISC", @@ -23819,15 +23780,6 @@ "node": ">=18" } }, - "node_modules/braintrust/node_modules/@types/node": { - "version": "18.19.43", - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "undici-types": "~5.26.4" - } - }, "node_modules/braintrust/node_modules/ai": { "version": "3.4.33", "license": "Apache-2.0", @@ -23933,12 +23885,6 @@ "@esbuild/win32-x64": "0.25.1" } }, - "node_modules/braintrust/node_modules/form-data-encoder": { - "version": "1.7.2", - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/braintrust/node_modules/graceful-fs": { "version": "4.2.11", "license": "ISC" @@ -23956,36 +23902,6 @@ "url": "https://github.com/sponsors/isaacs" } }, - "node_modules/braintrust/node_modules/openai": { - "version": "4.95.0", - "license": "Apache-2.0", - "optional": true, - "peer": true, - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, - "bin": { - "openai": "bin/cli" - }, - "peerDependencies": { - "ws": "^8.18.0", - "zod": "^3.23.8" - }, - "peerDependenciesMeta": { - "ws": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, "node_modules/braintrust/node_modules/source-map": { "version": "0.7.4", "license": "BSD-3-Clause", @@ -34208,10 +34124,6 @@ "node": ">= 12" } }, - "node_modules/llamaindex/node_modules/form-data-encoder": { - "version": "1.7.2", - "license": "MIT" - }, "node_modules/llamaindex/node_modules/js-base64": { "version": "3.7.7", "license": "BSD-3-Clause" @@ -34238,63 +34150,6 @@ "url": "https://opencollective.com/node-fetch" } }, - "node_modules/llamaindex/node_modules/openai": { - "version": "4.95.0", - "license": "Apache-2.0", - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, - "bin": { - "openai": "bin/cli" - }, - "peerDependencies": { - "ws": "^8.18.0", - "zod": "^3.23.8" - }, - "peerDependenciesMeta": { - "ws": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, - "node_modules/llamaindex/node_modules/openai/node_modules/@types/node": { - "version": "18.19.70", - "license": "MIT", - "dependencies": { - "undici-types": "~5.26.4" - } - }, - 
"node_modules/llamaindex/node_modules/openai/node_modules/node-fetch": { - "version": "2.7.0", - "license": "MIT", - "dependencies": { - "whatwg-url": "^5.0.0" - }, - "engines": { - "node": "4.x || >=6.0.0" - }, - "peerDependencies": { - "encoding": "^0.1.0" - }, - "peerDependenciesMeta": { - "encoding": { - "optional": true - } - } - }, - "node_modules/llamaindex/node_modules/openai/node_modules/undici-types": { - "version": "5.26.5", - "license": "MIT" - }, "node_modules/llamaindex/node_modules/qs": { "version": "6.13.1", "license": "BSD-3-Clause", @@ -34331,26 +34186,10 @@ "safe-buffer": "~5.2.0" } }, - "node_modules/llamaindex/node_modules/tr46": { - "version": "0.0.3", - "license": "MIT" - }, "node_modules/llamaindex/node_modules/undici-types": { "version": "6.19.8", "license": "MIT" }, - "node_modules/llamaindex/node_modules/webidl-conversions": { - "version": "3.0.1", - "license": "BSD-2-Clause" - }, - "node_modules/llamaindex/node_modules/whatwg-url": { - "version": "5.0.0", - "license": "MIT", - "dependencies": { - "tr46": "~0.0.3", - "webidl-conversions": "^3.0.0" - } - }, "node_modules/load-json-file": { "version": "6.2.0", "dev": true, @@ -41518,20 +41357,47 @@ } }, "node_modules/openai": { - "version": "3.3.0", - "license": "MIT", + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", "dependencies": { - "axios": "^0.26.0", - "form-data": "^4.0.0" + "@types/node": "^18.11.18", + "@types/node-fetch": "^2.6.4", + "abort-controller": "^3.0.0", + "agentkeepalive": "^4.2.1", + "form-data-encoder": "1.7.2", + "formdata-node": "^4.3.2", + "node-fetch": "^2.6.7" + }, + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.23.8" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } } }, - "node_modules/openai/node_modules/axios": { - "version": "0.26.1", - "license": "MIT", + "node_modules/openai/node_modules/@types/node": { + "version": "18.19.112", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.112.tgz", + "integrity": "sha512-i+Vukt9POdS/MBI7YrrkkI5fMfwFtOjphSmt4WXYLfwqsfr6z/HdCx7LqT9M7JktGob8WNgj8nFB4TbGNE4Cog==", "dependencies": { - "follow-redirects": "^1.14.8" + "undici-types": "~5.26.4" } }, + "node_modules/openai/node_modules/form-data-encoder": { + "version": "1.7.2", + "resolved": "https://registry.npmjs.org/form-data-encoder/-/form-data-encoder-1.7.2.tgz", + "integrity": "sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A==" + }, "node_modules/openapi-types": { "version": "12.1.3", "license": "MIT" @@ -53490,10 +53356,6 @@ "zod-to-json-schema": "^3.22.5" } }, - "packages/benchmarks/node_modules/form-data-encoder": { - "version": "1.7.2", - "license": "MIT" - }, "packages/benchmarks/node_modules/js-yaml": { "version": "4.1.0", "license": "MIT", @@ -53508,30 +53370,6 @@ "version": "1.0.0", "license": "MIT" }, - "packages/benchmarks/node_modules/openai": { - "version": "4.47.1", - "license": "Apache-2.0", - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7", - "web-streams-polyfill": "^3.2.1" - }, - "bin": { - "openai": "bin/cli" - } - }, - 
"packages/benchmarks/node_modules/openai/node_modules/@types/node": { - "version": "18.19.86", - "license": "MIT", - "dependencies": { - "undici-types": "~5.26.4" - } - }, "packages/benchmarks/node_modules/yaml": { "version": "2.7.1", "license": "ISC", @@ -54592,7 +54430,6 @@ "langchain": "^0.2.9", "lodash.clonedeep": "^4.5.0", "mongodb-rag-core": "*", - "openai": "^3.2.1", "pm2": "^5.3.0", "rate-limit-mongo": "^2.3.2", "stream-json": "^1.8.0", @@ -54698,36 +54535,6 @@ "uuid": "dist/bin/uuid" } }, - "packages/mongodb-chatbot-server/node_modules/@langchain/core/node_modules/openai": { - "version": "4.95.0", - "license": "Apache-2.0", - "optional": true, - "peer": true, - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, - "bin": { - "openai": "bin/cli" - }, - "peerDependencies": { - "ws": "^8.18.0", - "zod": "^3.23.8" - }, - "peerDependenciesMeta": { - "ws": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, "packages/mongodb-chatbot-server/node_modules/@langchain/core/node_modules/uuid": { "version": "10.0.0", "funding": [ @@ -54739,15 +54546,6 @@ "uuid": "dist/bin/uuid" } }, - "packages/mongodb-chatbot-server/node_modules/@types/node": { - "version": "18.19.86", - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "undici-types": "~5.26.4" - } - }, "packages/mongodb-chatbot-server/node_modules/ansi-styles": { "version": "5.2.0", "license": "MIT", @@ -54772,12 +54570,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "packages/mongodb-chatbot-server/node_modules/form-data-encoder": { - "version": "1.7.2", - "license": "MIT", - "optional": true, - "peer": true - }, "packages/mongodb-chatbot-server/node_modules/ip-address": { "version": "8.1.0", "license": "MIT", @@ -55065,36 +54857,6 @@ "uuid": "dist/bin/uuid" } }, - "packages/mongodb-chatbot-server/node_modules/langchain/node_modules/openai": { - "version": "4.95.0", - "license": "Apache-2.0", - "optional": true, - "peer": true, - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, - "bin": { - "openai": "bin/cli" - }, - "peerDependencies": { - "ws": "^8.18.0", - "zod": "^3.23.8" - }, - "peerDependenciesMeta": { - "ws": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, "packages/mongodb-chatbot-server/node_modules/langchain/node_modules/uuid": { "version": "10.0.0", "funding": [ @@ -58249,41 +58011,6 @@ "node": ">=16 || 14 >=14.17" } }, - "packages/mongodb-rag-core/node_modules/openai": { - "version": "4.95.0", - "license": "Apache-2.0", - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, - "bin": { - "openai": "bin/cli" - }, - "peerDependencies": { - "ws": "^8.18.0", - "zod": "^3.23.8" - }, - "peerDependenciesMeta": { - "ws": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, - "packages/mongodb-rag-core/node_modules/openai/node_modules/@types/node": { - "version": "18.19.61", - "license": "MIT", - "dependencies": { - "undici-types": "~5.26.4" - } - }, "packages/mongodb-rag-core/node_modules/path-scurry": { 
"version": "2.0.0", "license": "BlueOak-1.0.0", @@ -60638,45 +60365,6 @@ "uuid": "dist/bin/uuid" } }, - "packages/release-notes-generator/node_modules/openai": { - "version": "4.104.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", - "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", - "license": "Apache-2.0", - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, - "bin": { - "openai": "bin/cli" - }, - "peerDependencies": { - "ws": "^8.18.0", - "zod": "^3.23.8" - }, - "peerDependenciesMeta": { - "ws": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, - "packages/release-notes-generator/node_modules/openai/node_modules/@types/node": { - "version": "18.19.112", - "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.112.tgz", - "integrity": "sha512-i+Vukt9POdS/MBI7YrrkkI5fMfwFtOjphSmt4WXYLfwqsfr6z/HdCx7LqT9M7JktGob8WNgj8nFB4TbGNE4Cog==", - "license": "MIT", - "dependencies": { - "undici-types": "~5.26.4" - } - }, "packages/release-notes-generator/node_modules/path-scurry": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-2.0.0.tgz", diff --git a/packages/mongodb-chatbot-server/package.json b/packages/mongodb-chatbot-server/package.json index dc33559ea..1b996953f 100644 --- a/packages/mongodb-chatbot-server/package.json +++ b/packages/mongodb-chatbot-server/package.json @@ -50,7 +50,6 @@ "langchain": "^0.2.9", "lodash.clonedeep": "^4.5.0", "mongodb-rag-core": "*", - "openai": "^3.2.1", "pm2": "^5.3.0", "rate-limit-mongo": "^2.3.2", "stream-json": "^1.8.0", diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts index f8d527557..95feea775 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts @@ -4,9 +4,25 @@ import { Express } from "express"; import { DEFAULT_API_PREFIX } from "../../app"; import { makeTestApp } from "../../test/testHelpers"; import { MONGO_CHAT_MODEL } from "../../test/testConfig"; +import { ERROR_TYPE, ERROR_CODE } from "./errors"; +import { + INPUT_STRING_ERR_MSG, + INPUT_ARRAY_ERR_MSG, + METADATA_LENGTH_ERR_MSG, + TEMPERATURE_ERR_MSG, + STREAM_ERR_MSG, + MODEL_NOT_SUPPORTED_ERR_MSG, + MAX_OUTPUT_TOKENS_ERR_MSG, +} from "./createResponse"; jest.setTimeout(100000); +const badRequestError = (message: string) => ({ + type: ERROR_TYPE, + code: ERROR_CODE.INVALID_REQUEST_ERROR, + message, +}); + describe("POST /responses", () => { const endpointUrl = `${DEFAULT_API_PREFIX}/responses`; let app: Express; @@ -321,7 +337,6 @@ describe("POST /responses", () => { }); }); - // TODO: In EAI-1126, we will need to change the error types to match the OpenAI spec describe("Invalid requests", () => { it("Should return 400 with an empty input string", async () => { const response = await request(app) @@ -335,7 +350,9 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError(`Path: body.input - ${INPUT_STRING_ERR_MSG}`) + ); }); it("Should return 400 with an empty message array", async () => { @@ 
-350,7 +367,9 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError(`Path: body.input - ${INPUT_ARRAY_ERR_MSG}`) + ); }); it("Should return 400 if model is not mongodb-chat-latest", async () => { @@ -365,9 +384,9 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ - error: "Model gpt-4o-mini is not supported.", - }); + expect(response.body.error).toEqual( + badRequestError(MODEL_NOT_SUPPORTED_ERR_MSG("gpt-4o-mini")) + ); }); it("Should return 400 if stream is not true", async () => { @@ -382,10 +401,14 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError(`Path: body.stream - ${STREAM_ERR_MSG}`) + ); }); it("Should return 400 if max_output_tokens is > 4000", async () => { + const max_output_tokens = 4001; + const response = await request(app) .post(endpointUrl) .set("X-Forwarded-For", ipAddress) @@ -394,14 +417,13 @@ describe("POST /responses", () => { model: MONGO_CHAT_MODEL, stream: true, input: "What is MongoDB?", - max_output_tokens: 4001, + max_output_tokens, }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ - error: - "Max output tokens 4001 is greater than the maximum allowed 4000.", - }); + expect(response.body.error).toEqual( + badRequestError(MAX_OUTPUT_TOKENS_ERR_MSG(max_output_tokens, 4000)) + ); }); it("Should return 400 if metadata has too many fields", async () => { @@ -421,7 +443,9 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError(`Path: body.metadata - ${METADATA_LENGTH_ERR_MSG}`) + ); }); it("Should return 400 if metadata value is too long", async () => { @@ -437,7 +461,11 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError( + "Path: body.metadata.key1 - String must contain at most 512 character(s)" + ) + ); }); it("Should return 400 if temperature is not 0", async () => { @@ -453,7 +481,9 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError(`Path: body.temperature - ${TEMPERATURE_ERR_MSG}`) + ); }); it("Should return 400 if messages contain an invalid role", async () => { @@ -470,7 +500,9 @@ describe("POST /responses", () => { ], }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError("Path: body.input - Invalid input") + ); }); it("Should return 400 if function_call has an invalid status", async () => { @@ -492,7 +524,9 @@ describe("POST /responses", () => { ], }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError("Path: body.input - Invalid input") + ); }); it("Should return 400 if function_call_output has an invalid status", async () => { @@ -513,7 +547,9 @@ describe("POST /responses", () => { ], }); 
expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError("Path: body.input - Invalid input") + ); }); it("Should return 400 with an invalid tool_choice string", async () => { @@ -529,7 +565,9 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError("Path: body.tool_choice - Invalid input") + ); }); it("Should return 400 if max_output_tokens is negative", async () => { @@ -545,7 +583,11 @@ describe("POST /responses", () => { }); expect(response.statusCode).toBe(400); - expect(response.body).toEqual({ error: "Invalid request" }); + expect(response.body.error).toEqual( + badRequestError( + "Path: body.max_output_tokens - Number must be greater than or equal to 0" + ) + ); }); }); }); diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts index daa41febc..db70249e3 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts @@ -1,20 +1,36 @@ import { z } from "zod"; -import { +import type { Request as ExpressRequest, Response as ExpressResponse, } from "express"; -import { RequestError, makeRequestError } from "../conversations/utils"; +import { APIError } from "mongodb-rag-core/openai"; import { SomeExpressRequest } from "../../middleware"; -import { getRequestId, sendErrorResponse } from "../../utils"; +import { getRequestId } from "../../utils"; import { GenerateResponse } from "../../processors"; +import { + makeBadRequestError, + makeInternalServerError, + generateZodErrorMessage, + sendErrorResponse, + ERROR_TYPE, +} from "./errors"; + +export const INPUT_STRING_ERR_MSG = "Input must be a non-empty string"; +export const INPUT_ARRAY_ERR_MSG = + "Input must be a string or array of messages. See https://platform.openai.com/docs/api-reference/responses/create#responses-create-input for more information."; +export const METADATA_LENGTH_ERR_MSG = "Too many metadata fields. Max 16."; +export const TEMPERATURE_ERR_MSG = "Temperature must be 0 or unset"; +export const STREAM_ERR_MSG = "'stream' must be true"; +export const MODEL_NOT_SUPPORTED_ERR_MSG = (model: string) => + `Path: body.model - ${model} is not supported.`; +export const MAX_OUTPUT_TOKENS_ERR_MSG = (input: number, max: number) => + `Path: body.max_output_tokens - ${input} is greater than the maximum allowed ${max}.`; -export const CreateResponseRequestBodySchema = z.object({ +const CreateResponseRequestBodySchema = z.object({ model: z.string(), instructions: z.string().optional(), input: z.union([ - z - .string() - .refine((input) => input.length > 0, "Input must be a non-empty string"), + z.string().refine((input) => input.length > 0, INPUT_STRING_ERR_MSG), z .array( z.union([ @@ -57,7 +73,7 @@ export const CreateResponseRequestBodySchema = z.object({ }), ]) ) - .refine((input) => input.length > 0, "Input must be a non-empty array"), + .refine((input) => input.length > 0, INPUT_ARRAY_ERR_MSG), ]), max_output_tokens: z.number().min(0).default(1000), metadata: z @@ -65,7 +81,7 @@ export const CreateResponseRequestBodySchema = z.object({ .optional() .refine( (metadata) => Object.keys(metadata ?? {}).length <= 16, - "Too many metadata fields. Max 16." 
+ METADATA_LENGTH_ERR_MSG ), previous_response_id: z .string() @@ -76,16 +92,10 @@ export const CreateResponseRequestBodySchema = z.object({ .optional() .describe("Whether to store the response in the conversation.") .default(true), - stream: z.literal(true, { - errorMap: () => ({ message: "'stream' must be true" }), - }), + stream: z.boolean().refine((stream) => stream, STREAM_ERR_MSG), temperature: z - .union([ - z.literal(0, { - errorMap: () => ({ message: "Temperature must be 0 or unset" }), - }), - z.undefined(), - ]) + .number() + .refine((temperature) => temperature === 0, TEMPERATURE_ERR_MSG) .optional() .describe("Temperature for the model. Defaults to 0.") .default(0), @@ -120,7 +130,7 @@ export const CreateResponseRequestBodySchema = z.object({ user: z.string().optional().describe("The user ID of the user."), }); -export const CreateResponseRequest = SomeExpressRequest.merge( +const CreateResponseRequestSchema = SomeExpressRequest.merge( z.object({ headers: z.object({ "req-id": z.string(), @@ -129,7 +139,7 @@ export const CreateResponseRequest = SomeExpressRequest.merge( }) ); -export type CreateResponseRequest = z.infer; +export type CreateResponseRequest = z.infer; export interface CreateResponseRouteParams { generateResponse: GenerateResponse; @@ -138,7 +148,7 @@ export interface CreateResponseRouteParams { } export function makeCreateResponseRoute({ - // generateResponse, + generateResponse, supportedModels, maxOutputTokens, }: CreateResponseRouteParams) { @@ -147,48 +157,54 @@ export function makeCreateResponseRoute({ res: ExpressResponse<{ status: string }, any> ) => { const reqId = getRequestId(req); + const headers = req.headers as Record; + try { const { body: { model, max_output_tokens }, } = req; + // --- INPUT VALIDATION --- + const { error } = await CreateResponseRequestSchema.safeParseAsync(req); + if (error) { + throw makeBadRequestError({ + error: new Error(generateZodErrorMessage(error)), + headers, + }); + } + // --- MODEL CHECK --- if (!supportedModels.includes(model)) { - throw makeRequestError({ - httpStatus: 400, - message: `Model ${model} is not supported.`, + throw makeBadRequestError({ + error: new Error(MODEL_NOT_SUPPORTED_ERR_MSG(model)), + headers, }); } // --- MAX OUTPUT TOKENS CHECK --- if (max_output_tokens > maxOutputTokens) { - throw makeRequestError({ - httpStatus: 400, - message: `Max output tokens ${max_output_tokens} is greater than the maximum allowed ${maxOutputTokens}.`, + throw makeBadRequestError({ + error: new Error( + MAX_OUTPUT_TOKENS_ERR_MSG(max_output_tokens, maxOutputTokens) + ), + headers, }); } - // TODO: actually use this call - // generateResponse(); - // TODO: do something with maxOutputTokens (validate result length or pass to generateResponse?) + // TODO: actually implement this call + await generateResponse({} as any); return res.status(200).send({ status: "ok" }); } catch (error) { - // TODO: better error handling, in line with the Responses API - const { httpStatus, message } = - (error as Error).name === "RequestError" - ? (error as RequestError) - : makeRequestError({ - message: (error as Error).message, - stack: (error as Error).stack, - httpStatus: 500, - }); + const standardError = + (error as APIError)?.type === ERROR_TYPE + ? 
(error as APIError) + : makeInternalServerError({ error: error as Error, headers }); sendErrorResponse({ res, reqId, - httpStatus, - errorMessage: message, + error: standardError, }); } }; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/errors.ts b/packages/mongodb-chatbot-server/src/routes/responses/errors.ts new file mode 100644 index 000000000..f5b6822e9 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/responses/errors.ts @@ -0,0 +1,125 @@ +import { + APIError, + BadRequestError, + InternalServerError, + NotFoundError, + RateLimitError, +} from "mongodb-rag-core/openai"; +import { logger } from "mongodb-rag-core"; +import type { Response as ExpressResponse } from "express"; +import type { ZodError } from "zod"; +import { generateErrorMessage } from "zod-error"; + +interface SendErrorResponseParams { + reqId: string; + res: ExpressResponse; + error: APIError; +} + +export const sendErrorResponse = ({ + reqId, + res, + error, +}: SendErrorResponseParams) => { + const httpStatus = error.status ?? 500; + + logger.error({ + reqId, + message: `Responding with ${httpStatus} status and error message: ${error.message}.`, + }); + + if (!res.writableEnded) { + return res.status(httpStatus).json(error); + } +}; + +// --- OPENAI ERROR CONSTANTS --- +export const ERROR_TYPE = "error"; +export enum ERROR_CODE { + INVALID_REQUEST_ERROR = "invalid_request_error", + NOT_FOUND_ERROR = "not_found_error", + RATE_LIMIT_ERROR = "rate_limit_error", + SERVER_ERROR = "server_error", +} + +// --- OPENAI ERROR WRAPPERS --- +interface MakeOpenAIErrorParams { + error: Error; + headers: Record; +} + +export const makeInternalServerError = ({ + error, + headers, +}: MakeOpenAIErrorParams): APIError => { + const message = error.message ?? "Internal server error"; + const _error = { + ...error, + type: ERROR_TYPE, + code: ERROR_CODE.SERVER_ERROR, + message, + }; + return new InternalServerError(500, _error, message, headers); +}; + +export const makeBadRequestError = ({ + error, + headers, +}: MakeOpenAIErrorParams): APIError => { + const message = error.message ?? "Bad request"; + const _error = { + ...error, + type: ERROR_TYPE, + code: ERROR_CODE.INVALID_REQUEST_ERROR, + message, + }; + return new BadRequestError(400, _error, message, headers); +}; + +export const makeNotFoundError = ({ + error, + headers, +}: MakeOpenAIErrorParams): APIError => { + const message = error.message ?? "Not found"; + const _error = { + ...error, + type: ERROR_TYPE, + code: ERROR_CODE.NOT_FOUND_ERROR, + message, + }; + return new NotFoundError(404, _error, message, headers); +}; + +export const makeRateLimitError = ({ + error, + headers, +}: MakeOpenAIErrorParams): APIError => { + const message = error.message ?? 
"Rate limit exceeded"; + const _error = { + ...error, + type: ERROR_TYPE, + code: ERROR_CODE.RATE_LIMIT_ERROR, + message, + }; + return new RateLimitError(429, _error, message, headers); +}; + +// --- ZOD VALIDATION ERROR MESSAGE GENERATION --- +export const generateZodErrorMessage = (error: ZodError) => { + return generateErrorMessage(error.issues, { + delimiter: { + component: " - ", + }, + path: { + enabled: true, + type: "objectNotation", + }, + code: { + enabled: false, + }, + message: { + enabled: true, + label: "", + }, + }); +}; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts index 8f6e156c0..310f84a27 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts @@ -4,6 +4,7 @@ import { DEFAULT_API_PREFIX } from "../../app"; import { makeTestApp } from "../../test/testHelpers"; import { makeTestAppConfig } from "../../test/testHelpers"; import { MONGO_CHAT_MODEL } from "../../test/testConfig"; +import { ERROR_TYPE, ERROR_CODE, makeBadRequestError } from "./errors"; jest.setTimeout(60000); @@ -47,4 +48,111 @@ describe("Responses Router", () => { expect(res.status).toBe(200); }); + + it("should return 500 when handling an unknown error", async () => { + const errorMessage = "Unknown error"; + const { app, origin } = await makeTestApp({ + ...appConfig, + responsesRouterConfig: { + createResponse: { + supportedModels: [MONGO_CHAT_MODEL], + maxOutputTokens: 4000, + generateResponse: () => Promise.reject(new Error(errorMessage)), + }, + }, + }); + + const res = await request(app) + .post(responsesEndpoint) + .set("X-FORWARDED-FOR", ipAddress) + .set("Origin", origin) + .send(validRequestBody); + + expect(res.status).toBe(500); + expect(res.body.type).toBe(ERROR_TYPE); + expect(res.body.code).toBe(ERROR_CODE.SERVER_ERROR); + expect(res.body.error).toEqual({ + type: ERROR_TYPE, + code: ERROR_CODE.SERVER_ERROR, + message: errorMessage, + }); + }); + + it("should return the openai error when service throws an openai error", async () => { + const errorMessage = "Bad request input"; + const { app, origin } = await makeTestApp({ + ...appConfig, + responsesRouterConfig: { + createResponse: { + supportedModels: [MONGO_CHAT_MODEL], + maxOutputTokens: 4000, + generateResponse: () => + Promise.reject( + makeBadRequestError({ + error: new Error(errorMessage), + headers: {}, + }) + ), + }, + }, + }); + + const res = await request(app) + .post(responsesEndpoint) + .set("X-FORWARDED-FOR", ipAddress) + .set("Origin", origin) + .send(validRequestBody); + + expect(res.status).toBe(400); + expect(res.body.type).toBe(ERROR_TYPE); + expect(res.body.code).toBe(ERROR_CODE.INVALID_REQUEST_ERROR); + expect(res.body.error).toEqual({ + type: ERROR_TYPE, + code: ERROR_CODE.INVALID_REQUEST_ERROR, + message: errorMessage, + }); + }); + + test("Should apply responses router rate limit and return an openai error", async () => { + const rateLimitErrorMessage = "Error: rate limit exceeded!"; + + const { app, origin } = await makeTestApp({ + responsesRouterConfig: { + rateLimitConfig: { + routerRateLimitConfig: { + windowMs: 50000, // Big window to cover test duration + max: 1, // Only one request should be allowed + message: rateLimitErrorMessage, + }, + }, + }, + }); + + const successRes = await request(app) + .post(responsesEndpoint) + .set("X-FORWARDED-FOR", ipAddress) + .set("Origin", 
origin) + .send(validRequestBody); + + const rateLimitedRes = await request(app) + .post(responsesEndpoint) + .set("X-FORWARDED-FOR", ipAddress) + .set("Origin", origin) + .send(validRequestBody); + + expect(successRes.status).toBe(200); + expect(successRes.error).toBeFalsy(); + + expect(rateLimitedRes.status).toBe(429); + expect(rateLimitedRes.error).toBeTruthy(); + expect(rateLimitedRes.body.type).toBe(ERROR_TYPE); + expect(rateLimitedRes.body.code).toBe(ERROR_CODE.RATE_LIMIT_ERROR); + expect(rateLimitedRes.body.error).toEqual({ + type: ERROR_TYPE, + code: ERROR_CODE.RATE_LIMIT_ERROR, + message: rateLimitErrorMessage, + }); + expect(rateLimitedRes.body.headers["x-forwarded-for"]).toBe(ipAddress); + expect(rateLimitedRes.body.headers["origin"]).toBe(origin); + }); }); diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts index 9079379b6..000614c07 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts @@ -1,12 +1,20 @@ import Router from "express-promise-router"; -import validateRequestSchema from "../../middleware/validateRequestSchema"; +import { makeCreateResponseRoute } from "./createResponse"; +import type { GenerateResponse } from "../../processors"; +import { getRequestId } from "../../utils"; import { - makeCreateResponseRoute, - CreateResponseRequest, -} from "./createResponse"; -import { GenerateResponse } from "../../processors"; + makeRateLimit, + makeSlowDown, + type RateLimitOptions, + type SlowDownOptions, +} from "../../middleware"; +import { makeRateLimitError, sendErrorResponse } from "./errors"; export interface ResponsesRouterParams { + rateLimitConfig?: { + routerRateLimitConfig?: RateLimitOptions; + routerSlowDownConfig?: SlowDownOptions; + }; createResponse: { generateResponse: GenerateResponse; supportedModels: string[]; @@ -17,17 +25,36 @@ export interface ResponsesRouterParams { /** Constructor function to make the /responses/* Express.js router. */ -export function makeResponsesRouter({ createResponse }: ResponsesRouterParams) { +export function makeResponsesRouter({ + rateLimitConfig, + createResponse, +}: ResponsesRouterParams) { const responsesRouter = Router(); - // TODO: add rate limit config + /* + Global rate limit the requests to the responsesRouter. + */ + const rateLimit = makeRateLimit({ + ...rateLimitConfig?.routerRateLimitConfig, + handler: (req, res, next, options) => { + const reqId = getRequestId(req); + const error = makeRateLimitError({ + error: new Error(options.message), + headers: req.headers as Record, + }); + return sendErrorResponse({ reqId, res, error }); + }, + }); + responsesRouter.use(rateLimit); + /* + Slow down the response to the responsesRouter after certain number + of requests in the time window. 
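+    Applied to every request on this router; uses rateLimitConfig?.routerSlowDownConfig when provided.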
+ */ + const globalSlowDown = makeSlowDown(rateLimitConfig?.routerSlowDownConfig); + responsesRouter.use(globalSlowDown); // Create Response API - responsesRouter.post( - "/", - validateRequestSchema(CreateResponseRequest), - makeCreateResponseRoute(createResponse) - ); + responsesRouter.post("/", makeCreateResponseRoute(createResponse)); return responsesRouter; } diff --git a/packages/mongodb-chatbot-server/src/test/testHelpers.ts b/packages/mongodb-chatbot-server/src/test/testHelpers.ts index 0455abcd9..5f68b1aff 100644 --- a/packages/mongodb-chatbot-server/src/test/testHelpers.ts +++ b/packages/mongodb-chatbot-server/src/test/testHelpers.ts @@ -18,6 +18,10 @@ export async function makeTestAppConfig( ...config.conversationsRouterConfig, ...(defaultConfigOverrides?.conversationsRouterConfig ?? {}), }, + responsesRouterConfig: { + ...config.responsesRouterConfig, + ...(defaultConfigOverrides?.responsesRouterConfig ?? {}), + }, }; assert(memoryDb, "memoryDb must be defined"); return { appConfig, systemPrompt, mongodb: memoryDb }; @@ -25,9 +29,10 @@ export async function makeTestAppConfig( export type PartialAppConfig = Omit< Partial, - "conversationsRouterConfig" + "conversationsRouterConfig" | "responsesRouterConfig" > & { conversationsRouterConfig?: Partial; + responsesRouterConfig?: Partial; }; export const TEST_ORIGIN = "http://localhost:5173"; From c55f53191b24cf0da7a756775c8490b4c4d6d109 Mon Sep 17 00:00:00 2001 From: Andrew Steinheiser Date: Tue, 1 Jul 2025 12:37:27 -0700 Subject: [PATCH 3/6] Handle messages for responses (#795) * cleaanup err msg constants * add failing tests to rag-core for message_id helper * add findByMessageId to conversation service * remove extra types from conversation service * add indexes to conversationsDb * add test for new getByMessageId service * remove duplicate export * add logic for getting conversation to createResponse * update configs for createResponse -- includes some cleanup * fix test for successful previous_message_id input * cleanup test variables * even cleaner tests * add logic to catch bad object ids * add more tests * more tests * last test * fix broken mock * cleanup tests * share logic for reaching maxUserMessages in a Conversation * bump --- .../src/config.ts | 10 +- .../addMessageToConversation.test.ts | 6 + .../conversations/addMessageToConversation.ts | 12 +- .../routes/responses/createResponse.test.ts | 647 ++++++++---------- .../src/routes/responses/createResponse.ts | 153 ++++- .../routes/responses/responsesRouter.test.ts | 77 +-- .../src/routes/responses/responsesRouter.ts | 3 + .../src/test/testConfig.ts | 16 +- packages/mongodb-rag-core/package.json | 1 - .../src/conversations/ConversationsService.ts | 15 + .../MongoDbConversations.test.ts | 16 + .../src/conversations/MongoDbConversations.ts | 44 +- 12 files changed, 507 insertions(+), 493 deletions(-) diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index c2192b2b6..7f9db1944 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -388,15 +388,11 @@ export const config: AppConfig = { }, responsesRouterConfig: { createResponse: { + conversations, + generateResponse, supportedModels: ["mongodb-chat-latest"], maxOutputTokens: 4000, - generateResponse: () => - Promise.resolve({ - messages: [ - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a database." 
}, - ], - }), + maxUserMessagesInConversation: 6, }, }, maxRequestTimeoutMs: 60000, diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts index dcf0680dc..20d1baa65 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts @@ -290,6 +290,9 @@ describe("POST /conversations/:conversationId/messages", () => { test("Should respond 500 if error with conversation service", async () => { const mockBrokenConversationsService: ConversationsService = { + async init() { + throw new Error("mock error"); + }, async create() { throw new Error("mock error"); }, @@ -302,6 +305,9 @@ describe("POST /conversations/:conversationId/messages", () => { async findById() { throw new Error("Error finding conversation"); }, + async findByMessageId() { + throw new Error("Error finding conversation by message id"); + }, async rateMessage() { throw new Error("mock error"); }, diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts index b41377863..ededb5ecd 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts @@ -31,6 +31,7 @@ import { GenerateResponse, GenerateResponseParams, } from "../../processors/GenerateResponse"; +import { hasTooManyUserMessagesInConversation } from "../responses/createResponse"; export const DEFAULT_MAX_INPUT_LENGTH = 3000; // magic number for max input size for LLM export const DEFAULT_MAX_USER_MESSAGES_IN_CONVERSATION = 7; // magic number for max messages in a conversation @@ -207,11 +208,12 @@ export function makeAddMessageToConversationRoute({ }); // --- MAX CONVERSATION LENGTH CHECK --- - const numUserMessages = conversation.messages.reduce( - (acc, message) => (message.role === "user" ? 
acc + 1 : acc), - 0 - ); - if (numUserMessages >= maxUserMessagesInConversation) { + if ( + hasTooManyUserMessagesInConversation( + conversation, + maxUserMessagesInConversation + ) + ) { // Omit the system prompt and assume the user always received one response per message throw makeRequestError({ httpStatus: 400, diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts index 95feea775..3f4ac49d5 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts @@ -1,19 +1,11 @@ import "dotenv/config"; import request from "supertest"; -import { Express } from "express"; -import { DEFAULT_API_PREFIX } from "../../app"; +import type { Express } from "express"; +import { DEFAULT_API_PREFIX, type AppConfig } from "../../app"; import { makeTestApp } from "../../test/testHelpers"; -import { MONGO_CHAT_MODEL } from "../../test/testConfig"; +import { basicResponsesRequestBody } from "../../test/testConfig"; import { ERROR_TYPE, ERROR_CODE } from "./errors"; -import { - INPUT_STRING_ERR_MSG, - INPUT_ARRAY_ERR_MSG, - METADATA_LENGTH_ERR_MSG, - TEMPERATURE_ERR_MSG, - STREAM_ERR_MSG, - MODEL_NOT_SUPPORTED_ERR_MSG, - MAX_OUTPUT_TOKENS_ERR_MSG, -} from "./createResponse"; +import { ERR_MSG, type CreateResponseRequest } from "./createResponse"; jest.setTimeout(100000); @@ -26,312 +18,231 @@ const badRequestError = (message: string) => ({ describe("POST /responses", () => { const endpointUrl = `${DEFAULT_API_PREFIX}/responses`; let app: Express; + let appConfig: AppConfig; let ipAddress: string; let origin: string; beforeEach(async () => { - ({ app, ipAddress, origin } = await makeTestApp()); + ({ app, ipAddress, origin, appConfig } = await makeTestApp()); }); + const makeCreateResponseRequest = ( + body?: Partial, + appOverride?: Express + ) => { + return request(appOverride ?? app) + .post(endpointUrl) + .set("X-Forwarded-For", ipAddress) + .set("Origin", origin) + .send({ ...basicResponsesRequestBody, ...body }); + }; + describe("Valid requests", () => { it("Should return 200 given a string input", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - }); + const response = await makeCreateResponseRequest(); expect(response.statusCode).toBe(200); }); it("Should return 200 given a message array input", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: [ - { role: "system", content: "You are a helpful assistant." }, - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a document database." }, - { role: "user", content: "What is a document database?" }, - ], - }); + const response = await makeCreateResponseRequest({ + input: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "What is MongoDB?" }, + { role: "assistant", content: "MongoDB is a document database." }, + { role: "user", content: "What is a document database?" 
}, + ], + }); expect(response.statusCode).toBe(200); }); it("Should return 200 given a valid request with instructions", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - instructions: "You are a helpful chatbot.", - }); + const response = await makeCreateResponseRequest({ + instructions: "You are a helpful chatbot.", + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with valid max_output_tokens", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - max_output_tokens: 4000, - }); + const response = await makeCreateResponseRequest({ + max_output_tokens: 4000, + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with valid metadata", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - metadata: { key1: "value1", key2: "value2" }, - }); + const response = await makeCreateResponseRequest({ + metadata: { key1: "value1", key2: "value2" }, + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with valid temperature", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - temperature: 0, - }); + const response = await makeCreateResponseRequest({ + temperature: 0, + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with previous_response_id", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - previous_response_id: "some-id", + const conversation = + await appConfig.conversationsRouterConfig.conversations.create({ + initialMessages: [{ role: "user", content: "What is MongoDB?" }], }); + const previousResponseId = conversation.messages[0].id; + const response = await makeCreateResponseRequest({ + previous_response_id: previousResponseId.toString(), + }); + expect(response.statusCode).toBe(200); }); - it("Should return 200 with user", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - user: "some-user-id", + it("Should return 200 if previous_response_id is the latest message", async () => { + const conversation = + await appConfig.conversationsRouterConfig.conversations.create({ + initialMessages: [ + { role: "user", content: "What is MongoDB?" }, + { role: "assistant", content: "MongoDB is a document database." }, + { role: "user", content: "What is a document database?" 
}, + ], }); + const previousResponseId = conversation.messages[2].id; + const response = await makeCreateResponseRequest({ + previous_response_id: previousResponseId.toString(), + }); + + expect(response.statusCode).toBe(200); + }); + + it("Should return 200 with user", async () => { + const response = await makeCreateResponseRequest({ + user: "some-user-id", + }); + expect(response.statusCode).toBe(200); }); it("Should return 200 with store=false", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - store: false, - }); + const response = await makeCreateResponseRequest({ + store: false, + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with store=true", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - store: true, - }); + const response = await makeCreateResponseRequest({ + store: true, + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with tools and tool_choice", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - tools: [ - { - name: "test-tool", - description: "A tool for testing.", - parameters: { - type: "object", - properties: { - query: { type: "string" }, - }, - required: ["query"], + const response = await makeCreateResponseRequest({ + tools: [ + { + name: "test-tool", + description: "A tool for testing.", + parameters: { + type: "object", + properties: { + query: { type: "string" }, }, + required: ["query"], }, - ], - tool_choice: "auto", - }); + }, + ], + tool_choice: "auto", + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with a specific function tool_choice", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - tools: [ - { - name: "test-tool", - description: "A tool for testing.", - parameters: { - type: "object", - properties: { - query: { type: "string" }, - }, - required: ["query"], + const response = await makeCreateResponseRequest({ + tools: [ + { + name: "test-tool", + description: "A tool for testing.", + parameters: { + type: "object", + properties: { + query: { type: "string" }, }, + required: ["query"], }, - ], - tool_choice: { - type: "function", - name: "test-tool", }, - }); + ], + tool_choice: { + type: "function", + name: "test-tool", + }, + }); expect(response.statusCode).toBe(200); }); it("Should return 200 given a message array with function_call", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: [ - { role: "user", content: "What is MongoDB?" }, - { - type: "function_call", - id: "call123", - name: "my_function", - arguments: `{"query": "value"}`, - status: "in_progress", - }, - ], - }); + const response = await makeCreateResponseRequest({ + input: [ + { role: "user", content: "What is MongoDB?" 
}, + { + type: "function_call", + id: "call123", + name: "my_function", + arguments: `{"query": "value"}`, + status: "in_progress", + }, + ], + }); expect(response.statusCode).toBe(200); }); it("Should return 200 given a message array with function_call_output", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: [ - { role: "user", content: "What is MongoDB?" }, - { - type: "function_call_output", - call_id: "call123", - output: `{"result": "success"}`, - status: "completed", - }, - ], - }); + const response = await makeCreateResponseRequest({ + input: [ + { role: "user", content: "What is MongoDB?" }, + { + type: "function_call_output", + call_id: "call123", + output: `{"result": "success"}`, + status: "completed", + }, + ], + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with tool_choice 'none'", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - tool_choice: "none", - }); + const response = await makeCreateResponseRequest({ + tool_choice: "none", + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with tool_choice 'only'", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - tool_choice: "only", - }); + const response = await makeCreateResponseRequest({ + tool_choice: "only", + }); expect(response.statusCode).toBe(200); }); it("Should return 200 with an empty tools array", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - tools: [], - }); + const response = await makeCreateResponseRequest({ + tools: [], + }); expect(response.statusCode).toBe(200); }); @@ -339,90 +250,59 @@ describe("POST /responses", () => { describe("Invalid requests", () => { it("Should return 400 with an empty input string", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "", - }); + const response = await makeCreateResponseRequest({ + input: "", + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( - badRequestError(`Path: body.input - ${INPUT_STRING_ERR_MSG}`) + badRequestError(`Path: body.input - ${ERR_MSG.INPUT_STRING}`) ); }); it("Should return 400 with an empty message array", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: [], - }); + const response = await makeCreateResponseRequest({ + input: [], + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( - badRequestError(`Path: body.input - ${INPUT_ARRAY_ERR_MSG}`) + badRequestError(`Path: body.input - ${ERR_MSG.INPUT_ARRAY}`) ); }); it("Should return 400 if model is not mongodb-chat-latest", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - 
model: "gpt-4o-mini", - stream: true, - input: "What is MongoDB?", - }); + const response = await makeCreateResponseRequest({ + model: "gpt-4o-mini", + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( - badRequestError(MODEL_NOT_SUPPORTED_ERR_MSG("gpt-4o-mini")) + badRequestError(ERR_MSG.MODEL_NOT_SUPPORTED("gpt-4o-mini")) ); }); it("Should return 400 if stream is not true", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: false, - input: "What is MongoDB?", - }); + const response = await makeCreateResponseRequest({ + stream: false, + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( - badRequestError(`Path: body.stream - ${STREAM_ERR_MSG}`) + badRequestError(`Path: body.stream - ${ERR_MSG.STREAM}`) ); }); - it("Should return 400 if max_output_tokens is > 4000", async () => { + it("Should return 400 if max_output_tokens is > allowed limit", async () => { const max_output_tokens = 4001; - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - max_output_tokens, - }); + const response = await makeCreateResponseRequest({ + max_output_tokens, + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( - badRequestError(MAX_OUTPUT_TOKENS_ERR_MSG(max_output_tokens, 4000)) + badRequestError(ERR_MSG.MAX_OUTPUT_TOKENS(max_output_tokens, 4000)) ); }); @@ -431,34 +311,20 @@ describe("POST /responses", () => { for (let i = 0; i < 17; i++) { metadata[`key${i}`] = "value"; } - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - metadata, - }); + const response = await makeCreateResponseRequest({ + metadata, + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( - badRequestError(`Path: body.metadata - ${METADATA_LENGTH_ERR_MSG}`) + badRequestError(`Path: body.metadata - ${ERR_MSG.METADATA_LENGTH}`) ); }); it("Should return 400 if metadata value is too long", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - metadata: { key1: "a".repeat(513) }, - }); + const response = await makeCreateResponseRequest({ + metadata: { key1: "a".repeat(513) }, + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( @@ -469,36 +335,24 @@ describe("POST /responses", () => { }); it("Should return 400 if temperature is not 0", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - temperature: 0.5, - }); + const response = await makeCreateResponseRequest({ + temperature: 0.5 as any, + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( - badRequestError(`Path: body.temperature - ${TEMPERATURE_ERR_MSG}`) + badRequestError(`Path: body.temperature - ${ERR_MSG.TEMPERATURE}`) ); }); it("Should return 400 if messages contain an invalid role", async () => { - const response = await request(app) - .post(endpointUrl) - 
.set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: [ - { role: "user", content: "What is MongoDB?" }, - { role: "invalid-role", content: "This is an invalid role." }, - ], - }); + const response = await makeCreateResponseRequest({ + input: [ + { role: "user", content: "What is MongoDB?" }, + { role: "invalid-role" as any, content: "This is an invalid role." }, + ], + }); + expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( badRequestError("Path: body.input - Invalid input") @@ -506,23 +360,18 @@ describe("POST /responses", () => { }); it("Should return 400 if function_call has an invalid status", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: [ - { - type: "function_call", - id: "call123", - name: "my_function", - arguments: `{"query": "value"}`, - status: "invalid_status", - }, - ], - }); + const response = await makeCreateResponseRequest({ + input: [ + { + type: "function_call", + id: "call123", + name: "my_function", + arguments: `{"query": "value"}`, + status: "invalid_status" as any, + }, + ], + }); + expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( badRequestError("Path: body.input - Invalid input") @@ -530,22 +379,17 @@ describe("POST /responses", () => { }); it("Should return 400 if function_call_output has an invalid status", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: [ - { - type: "function_call_output", - call_id: "call123", - output: `{"result": "success"}`, - status: "invalid_status", - }, - ], - }); + const response = await makeCreateResponseRequest({ + input: [ + { + type: "function_call_output", + call_id: "call123", + output: `{"result": "success"}`, + status: "invalid_status" as any, + }, + ], + }); + expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( badRequestError("Path: body.input - Invalid input") @@ -553,16 +397,9 @@ describe("POST /responses", () => { }); it("Should return 400 with an invalid tool_choice string", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - tool_choice: "invalid_choice", - }); + const response = await makeCreateResponseRequest({ + tool_choice: "invalid_choice" as any, + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( @@ -571,16 +408,9 @@ describe("POST /responses", () => { }); it("Should return 400 if max_output_tokens is negative", async () => { - const response = await request(app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - max_output_tokens: -1, - }); + const response = await makeCreateResponseRequest({ + max_output_tokens: -1, + }); expect(response.statusCode).toBe(400); expect(response.body.error).toEqual( @@ -589,5 +419,76 @@ describe("POST /responses", () => { ) ); }); + + it("Should return 400 if previous_response_id is not a valid ObjectId", async () => { + const messageId = "some-id"; + + const response = await makeCreateResponseRequest({ + previous_response_id: 
messageId, + }); + + expect(response.statusCode).toBe(400); + expect(response.body.error).toEqual( + badRequestError(ERR_MSG.INVALID_OBJECT_ID(messageId)) + ); + }); + + it("Should return 400 if previous_response_id is not found", async () => { + const messageId = "123456789012123456789012"; + + const response = await makeCreateResponseRequest({ + previous_response_id: messageId, + }); + + expect(response.statusCode).toBe(400); + expect(response.body.error).toEqual( + badRequestError(ERR_MSG.MESSAGE_NOT_FOUND(messageId)) + ); + }); + + it("Should return 400 if previous_response_id is not the latest message", async () => { + const conversation = + await appConfig.conversationsRouterConfig.conversations.create({ + initialMessages: [ + { role: "user", content: "What is MongoDB?" }, + { role: "assistant", content: "MongoDB is a document database." }, + { role: "user", content: "What is a document database?" }, + ], + }); + + const previousResponseId = conversation.messages[0].id; + const response = await makeCreateResponseRequest({ + previous_response_id: previousResponseId.toString(), + }); + + expect(response.statusCode).toBe(400); + expect(response.body.error).toEqual( + badRequestError( + ERR_MSG.MESSAGE_NOT_LATEST(previousResponseId.toString()) + ) + ); + }); + + it("Should return 400 if there are too many messages in the conversation", async () => { + const maxUserMessagesInConversation = 0; + const newApp = await makeTestApp({ + responsesRouterConfig: { + ...appConfig.responsesRouterConfig, + createResponse: { + ...appConfig.responsesRouterConfig.createResponse, + maxUserMessagesInConversation, + }, + }, + }); + + const response = await makeCreateResponseRequest({}, newApp.app); + + expect(response.statusCode).toBe(400); + expect(response.body.error).toEqual( + badRequestError( + ERR_MSG.TOO_MANY_MESSAGES(maxUserMessagesInConversation) + ) + ); + }); }); }); diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts index db70249e3..3ab19ba47 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts @@ -3,10 +3,12 @@ import type { Request as ExpressRequest, Response as ExpressResponse, } from "express"; -import { APIError } from "mongodb-rag-core/openai"; +import { ObjectId } from "mongodb"; +import type { APIError } from "mongodb-rag-core/openai"; +import type { ConversationsService, Conversation } from "mongodb-rag-core"; import { SomeExpressRequest } from "../../middleware"; import { getRequestId } from "../../utils"; -import { GenerateResponse } from "../../processors"; +import type { GenerateResponse } from "../../processors"; import { makeBadRequestError, makeInternalServerError, @@ -15,22 +17,32 @@ import { ERROR_TYPE, } from "./errors"; -export const INPUT_STRING_ERR_MSG = "Input must be a non-empty string"; -export const INPUT_ARRAY_ERR_MSG = - "Input must be a string or array of messages. See https://platform.openai.com/docs/api-reference/responses/create#responses-create-input for more information."; -export const METADATA_LENGTH_ERR_MSG = "Too many metadata fields. 
Max 16."; -export const TEMPERATURE_ERR_MSG = "Temperature must be 0 or unset"; -export const STREAM_ERR_MSG = "'stream' must be true"; -export const MODEL_NOT_SUPPORTED_ERR_MSG = (model: string) => - `Path: body.model - ${model} is not supported.`; -export const MAX_OUTPUT_TOKENS_ERR_MSG = (input: number, max: number) => - `Path: body.max_output_tokens - ${input} is greater than the maximum allowed ${max}.`; +export const ERR_MSG = { + INPUT_STRING: "Input must be a non-empty string", + INPUT_ARRAY: + "Input must be a string or array of messages. See https://platform.openai.com/docs/api-reference/responses/create#responses-create-input for more information.", + METADATA_LENGTH: "Too many metadata fields. Max 16.", + TEMPERATURE: "Temperature must be 0 or unset", + STREAM: "'stream' must be true", + INVALID_OBJECT_ID: (id: string) => + `Path: body.previous_response_id - ${id} is not a valid ObjectId`, + MESSAGE_NOT_FOUND: (messageId: string) => + `Path: body.previous_response_id - Message ${messageId} not found`, + MESSAGE_NOT_LATEST: (messageId: string) => + `Path: body.previous_response_id - Message ${messageId} is not the latest message in the conversation`, + TOO_MANY_MESSAGES: (max: number) => + `Too many messages. You cannot send more than ${max} messages in this conversation.`, + MODEL_NOT_SUPPORTED: (model: string) => + `Path: body.model - ${model} is not supported.`, + MAX_OUTPUT_TOKENS: (input: number, max: number) => + `Path: body.max_output_tokens - ${input} is greater than the maximum allowed ${max}.`, +}; const CreateResponseRequestBodySchema = z.object({ model: z.string(), instructions: z.string().optional(), input: z.union([ - z.string().refine((input) => input.length > 0, INPUT_STRING_ERR_MSG), + z.string().refine((input) => input.length > 0, ERR_MSG.INPUT_STRING), z .array( z.union([ @@ -73,7 +85,7 @@ const CreateResponseRequestBodySchema = z.object({ }), ]) ) - .refine((input) => input.length > 0, INPUT_ARRAY_ERR_MSG), + .refine((input) => input.length > 0, ERR_MSG.INPUT_ARRAY), ]), max_output_tokens: z.number().min(0).default(1000), metadata: z @@ -81,7 +93,7 @@ const CreateResponseRequestBodySchema = z.object({ .optional() .refine( (metadata) => Object.keys(metadata ?? {}).length <= 16, - METADATA_LENGTH_ERR_MSG + ERR_MSG.METADATA_LENGTH ), previous_response_id: z .string() @@ -92,10 +104,10 @@ const CreateResponseRequestBodySchema = z.object({ .optional() .describe("Whether to store the response in the conversation.") .default(true), - stream: z.boolean().refine((stream) => stream, STREAM_ERR_MSG), + stream: z.boolean().refine((stream) => stream, ERR_MSG.STREAM), temperature: z .number() - .refine((temperature) => temperature === 0, TEMPERATURE_ERR_MSG) + .refine((temperature) => temperature === 0, ERR_MSG.TEMPERATURE) .optional() .describe("Temperature for the model. 
Defaults to 0.") .default(0), @@ -142,30 +154,32 @@ const CreateResponseRequestSchema = SomeExpressRequest.merge( export type CreateResponseRequest = z.infer; export interface CreateResponseRouteParams { + conversations: ConversationsService; generateResponse: GenerateResponse; supportedModels: string[]; maxOutputTokens: number; + maxUserMessagesInConversation: number; } export function makeCreateResponseRoute({ + conversations, generateResponse, supportedModels, maxOutputTokens, + maxUserMessagesInConversation, }: CreateResponseRouteParams) { return async ( req: ExpressRequest, - res: ExpressResponse<{ status: string }, any> + res: ExpressResponse<{ status: string }, any> // TODO: fix type ) => { const reqId = getRequestId(req); const headers = req.headers as Record; try { - const { - body: { model, max_output_tokens }, - } = req; - // --- INPUT VALIDATION --- - const { error } = await CreateResponseRequestSchema.safeParseAsync(req); + const { error, data } = await CreateResponseRequestSchema.safeParseAsync( + req + ); if (error) { throw makeBadRequestError({ error: new Error(generateZodErrorMessage(error)), @@ -173,10 +187,14 @@ export function makeCreateResponseRoute({ }); } + const { + body: { model, max_output_tokens, previous_response_id }, + } = data; + // --- MODEL CHECK --- if (!supportedModels.includes(model)) { throw makeBadRequestError({ - error: new Error(MODEL_NOT_SUPPORTED_ERR_MSG(model)), + error: new Error(ERR_MSG.MODEL_NOT_SUPPORTED(model)), headers, }); } @@ -185,7 +203,29 @@ export function makeCreateResponseRoute({ if (max_output_tokens > maxOutputTokens) { throw makeBadRequestError({ error: new Error( - MAX_OUTPUT_TOKENS_ERR_MSG(max_output_tokens, maxOutputTokens) + ERR_MSG.MAX_OUTPUT_TOKENS(max_output_tokens, maxOutputTokens) + ), + headers, + }); + } + + // --- LOAD CONVERSATION --- + const conversation = await loadConversationByMessageId({ + messageId: previous_response_id, + conversations, + headers, + }); + + // --- MAX CONVERSATION LENGTH CHECK --- + if ( + hasTooManyUserMessagesInConversation( + conversation, + maxUserMessagesInConversation + ) + ) { + throw makeBadRequestError({ + error: new Error( + ERR_MSG.TOO_MANY_MESSAGES(maxUserMessagesInConversation) ), headers, }); @@ -209,3 +249,66 @@ export function makeCreateResponseRoute({ } }; } + +interface LoadConversationByMessageIdParams { + messageId?: string; + conversations: ConversationsService; + headers: Record; +} + +async function loadConversationByMessageId({ + messageId, + conversations, + headers, +}: LoadConversationByMessageIdParams): Promise { + if (!messageId) { + return await conversations.create(); + } + + const conversation = await conversations.findByMessageId({ + messageId: convertToObjectId(messageId, headers), + }); + + if (!conversation) { + throw makeBadRequestError({ + error: new Error(ERR_MSG.MESSAGE_NOT_FOUND(messageId)), + headers, + }); + } + + const latestMessage = conversation.messages[conversation.messages.length - 1]; + if (latestMessage.id.toString() !== messageId) { + throw makeBadRequestError({ + error: new Error(ERR_MSG.MESSAGE_NOT_LATEST(messageId)), + headers, + }); + } + + return conversation; +} + +const convertToObjectId = ( + messageId: string, + headers: Record +): ObjectId => { + try { + return new ObjectId(messageId); + } catch (error) { + throw makeBadRequestError({ + error: new Error(ERR_MSG.INVALID_OBJECT_ID(messageId)), + headers, + }); + } +}; + +// ideally this doesn't need to be exported once nothing else relies on it (addMessageToConversation for now) 
+export const hasTooManyUserMessagesInConversation = ( + conversation: Conversation, + maxUserMessagesInConversation: number +) => { + const numUserMessages = conversation.messages.reduce( + (acc, message) => (message.role === "user" ? acc + 1 : acc), + 0 + ); + return numUserMessages >= maxUserMessagesInConversation; +}; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts index 310f84a27..9bfb9b29e 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts @@ -1,50 +1,40 @@ +import type { Express } from "express"; import request from "supertest"; import { AppConfig } from "../../app"; import { DEFAULT_API_PREFIX } from "../../app"; import { makeTestApp } from "../../test/testHelpers"; import { makeTestAppConfig } from "../../test/testHelpers"; -import { MONGO_CHAT_MODEL } from "../../test/testConfig"; +import { basicResponsesRequestBody } from "../../test/testConfig"; import { ERROR_TYPE, ERROR_CODE, makeBadRequestError } from "./errors"; +import { CreateResponseRequest } from "./createResponse"; jest.setTimeout(60000); describe("Responses Router", () => { const ipAddress = "127.0.0.1"; const responsesEndpoint = DEFAULT_API_PREFIX + "/responses"; - const validRequestBody = { - model: MONGO_CHAT_MODEL, - stream: true, - input: "What is MongoDB?", - }; let appConfig: AppConfig; beforeAll(async () => { ({ appConfig } = await makeTestAppConfig()); }); - it("should return 200 given a valid request", async () => { - const { app, origin } = await makeTestApp({ - ...appConfig, - responsesRouterConfig: { - createResponse: { - supportedModels: [MONGO_CHAT_MODEL], - maxOutputTokens: 4000, - generateResponse: () => - Promise.resolve({ - messages: [ - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a database." 
}, - ], - }), - }, - }, - }); - - const res = await request(app) + const makeCreateResponseRequest = ( + app: Express, + origin: string, + body?: Partial + ) => { + return request(app) .post(responsesEndpoint) - .set("X-FORWARDED-FOR", ipAddress) + .set("X-Forwarded-For", ipAddress) .set("Origin", origin) - .send(validRequestBody); + .send({ ...basicResponsesRequestBody, ...body }); + }; + + it("should return 200 given a valid request", async () => { + const { app, origin } = await makeTestApp(appConfig); + + const res = await makeCreateResponseRequest(app, origin); expect(res.status).toBe(200); }); @@ -54,19 +44,15 @@ describe("Responses Router", () => { const { app, origin } = await makeTestApp({ ...appConfig, responsesRouterConfig: { + ...appConfig.responsesRouterConfig, createResponse: { - supportedModels: [MONGO_CHAT_MODEL], - maxOutputTokens: 4000, + ...appConfig.responsesRouterConfig.createResponse, generateResponse: () => Promise.reject(new Error(errorMessage)), }, }, }); - const res = await request(app) - .post(responsesEndpoint) - .set("X-FORWARDED-FOR", ipAddress) - .set("Origin", origin) - .send(validRequestBody); + const res = await makeCreateResponseRequest(app, origin); expect(res.status).toBe(500); expect(res.body.type).toBe(ERROR_TYPE); @@ -83,9 +69,9 @@ describe("Responses Router", () => { const { app, origin } = await makeTestApp({ ...appConfig, responsesRouterConfig: { + ...appConfig.responsesRouterConfig, createResponse: { - supportedModels: [MONGO_CHAT_MODEL], - maxOutputTokens: 4000, + ...appConfig.responsesRouterConfig.createResponse, generateResponse: () => Promise.reject( makeBadRequestError({ @@ -97,11 +83,7 @@ describe("Responses Router", () => { }, }); - const res = await request(app) - .post(responsesEndpoint) - .set("X-FORWARDED-FOR", ipAddress) - .set("Origin", origin) - .send(validRequestBody); + const res = await makeCreateResponseRequest(app, origin); expect(res.status).toBe(400); expect(res.body.type).toBe(ERROR_TYPE); @@ -128,17 +110,8 @@ describe("Responses Router", () => { }, }); - const successRes = await request(app) - .post(responsesEndpoint) - .set("X-FORWARDED-FOR", ipAddress) - .set("Origin", origin) - .send(validRequestBody); - - const rateLimitedRes = await request(app) - .post(responsesEndpoint) - .set("X-FORWARDED-FOR", ipAddress) - .set("Origin", origin) - .send(validRequestBody); + const successRes = await makeCreateResponseRequest(app, origin); + const rateLimitedRes = await makeCreateResponseRequest(app, origin); expect(successRes.status).toBe(200); expect(successRes.error).toBeFalsy(); diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts index 000614c07..c88ecebbe 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.ts @@ -1,4 +1,5 @@ import Router from "express-promise-router"; +import type { ConversationsService } from "mongodb-rag-core"; import { makeCreateResponseRoute } from "./createResponse"; import type { GenerateResponse } from "../../processors"; import { getRequestId } from "../../utils"; @@ -16,9 +17,11 @@ export interface ResponsesRouterParams { routerSlowDownConfig?: SlowDownOptions; }; createResponse: { + conversations: ConversationsService; generateResponse: GenerateResponse; supportedModels: string[]; maxOutputTokens: number; + maxUserMessagesInConversation: number; }; } diff --git 
a/packages/mongodb-chatbot-server/src/test/testConfig.ts b/packages/mongodb-chatbot-server/src/test/testConfig.ts index 3f7337041..42064bcb8 100644 --- a/packages/mongodb-chatbot-server/src/test/testConfig.ts +++ b/packages/mongodb-chatbot-server/src/test/testConfig.ts @@ -173,6 +173,12 @@ export const mockGenerateResponse: GenerateResponse = async ({ export const MONGO_CHAT_MODEL = "mongodb-chat-latest"; +export const basicResponsesRequestBody = { + model: MONGO_CHAT_MODEL, + stream: true, + input: "What is MongoDB?", +}; + export async function makeDefaultConfig(): Promise { const conversations = makeMongoDbConversationsService(memoryDb); return { @@ -182,15 +188,11 @@ export async function makeDefaultConfig(): Promise { }, responsesRouterConfig: { createResponse: { + conversations, + generateResponse: mockGenerateResponse, supportedModels: [MONGO_CHAT_MODEL], maxOutputTokens: 4000, - generateResponse: () => - Promise.resolve({ - messages: [ - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a database." }, - ], - }), + maxUserMessagesInConversation: 6, }, }, maxRequestTimeoutMs: 30000, diff --git a/packages/mongodb-rag-core/package.json b/packages/mongodb-rag-core/package.json index 2b3b5e42e..720bfbcba 100644 --- a/packages/mongodb-rag-core/package.json +++ b/packages/mongodb-rag-core/package.json @@ -32,7 +32,6 @@ "./mongodb": "./build/mongodb.js", "./mongoDbMetadata": "./build/mongoDbMetadata/index.js", "./openai": "./build/openai.js", - "./aiSdk": "./build/aiSdk.js", "./braintrust": "./build/braintrust.js", "./dataSources": "./build/dataSources/index.js", "./models": "./build/models/index.js", diff --git a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts index 5bc97bfe5..a666ed132 100644 --- a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts +++ b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts @@ -213,6 +213,9 @@ export type AddManyConversationMessagesParams = { export interface FindByIdParams { _id: ObjectId; } +export interface FindByMessageIdParams { + messageId: ObjectId; +} export interface RateMessageParams { conversationId: ObjectId; messageId: ObjectId; @@ -245,6 +248,11 @@ export interface ConversationConstants { export interface ConversationsService { conversationConstants: ConversationConstants; + /** + Initialize the conversations service. + */ + init?: () => Promise; + /** Create a new {@link Conversation}. */ @@ -264,6 +272,13 @@ export interface ConversationsService { ) => Promise; findById: ({ _id }: FindByIdParams) => Promise; + /** + Find a {@link Conversation} by the id of a {@link Message} in the conversation. + */ + findByMessageId: ({ + messageId, + }: FindByMessageIdParams) => Promise; + /** Rate a {@link Message} in a {@link Conversation}. 
*/ diff --git a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts index 6e8a3888f..8bb20e7ef 100644 --- a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts +++ b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts @@ -201,6 +201,22 @@ describe("Conversations Service", () => { }); expect(conversationInDb).toBeNull(); }); + test("should find a conversation by message id", async () => { + const conversation = await conversationsService.create({ + initialMessages: [systemPrompt], + }); + const messageId = conversation.messages[0].id; + const conversationInDb = await conversationsService.findByMessageId({ + messageId, + }); + expect(conversationInDb).toEqual(conversation); + }); + test("should return null if cannot find a conversation by message id", async () => { + const conversationInDb = await conversationsService.findByMessageId({ + messageId: new BSON.ObjectId(), + }); + expect(conversationInDb).toBeNull(); + }); test("Should rate a message", async () => { const { _id: conversationId } = await conversationsService.create({ initialMessages: [systemPrompt], diff --git a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts index ea093f2d5..4210fe144 100644 --- a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts +++ b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts @@ -4,18 +4,12 @@ import { defaultConversationConstants, ConversationsService, Conversation, - CreateConversationParams, - AddConversationMessageParams, - FindByIdParams, - RateMessageParams, Message, UserMessage, - AddManyConversationMessagesParams, - AddSomeMessageParams, AssistantMessage, SystemMessage, - CommentMessageParams, ToolMessage, + AddSomeMessageParams, } from "./ConversationsService"; /** @@ -29,7 +23,14 @@ export function makeMongoDbConversationsService( database.collection("conversations"); return { conversationConstants, - async create(params?: CreateConversationParams) { + + async init() { + await conversationsCollection.createIndex("messages.id"); + // NOTE: createdAt index is only used via the production collection + await conversationsCollection.createIndex("createdAt"); + }, + + async create(params) { const customData = params?.customData; const initialMessages = params?.initialMessages; const newConversation = { @@ -56,7 +57,7 @@ export function makeMongoDbConversationsService( return newConversation; }, - async addConversationMessage(params: AddConversationMessageParams) { + async addConversationMessage(params) { const { conversationId, message } = params; const newMessage = createMessage(message); const updateResult = await conversationsCollection.updateOne( @@ -75,9 +76,7 @@ export function makeMongoDbConversationsService( return newMessage; }, - async addManyConversationMessages( - params: AddManyConversationMessagesParams - ) { + async addManyConversationMessages(params) { const { messages, conversationId } = params; const newMessages = messages.map(createMessage); const updateResult = await conversationsCollection.updateOne( @@ -98,16 +97,19 @@ export function makeMongoDbConversationsService( return newMessages; }, - async findById({ _id }: FindByIdParams) { + async findById({ _id }) { const conversation = await conversationsCollection.findOne({ _id }); return conversation; }, - async rateMessage({ - conversationId, - messageId, - rating, - }: 
RateMessageParams) { + async findByMessageId({ messageId }) { + const conversation = await conversationsCollection.findOne({ + "messages.id": messageId, + }); + return conversation; + }, + + async rateMessage({ conversationId, messageId, rating }) { const updateResult = await conversationsCollection.updateOne( { _id: conversationId, @@ -129,11 +131,7 @@ export function makeMongoDbConversationsService( return true; }, - async commentMessage({ - conversationId, - messageId, - comment, - }: CommentMessageParams) { + async commentMessage({ conversationId, messageId, comment }) { const updateResult = await conversationsCollection.updateOne( { _id: conversationId, From 60ab27d46df634fbc7e81299ed13daa209fd7ce1 Mon Sep 17 00:00:00 2001 From: Andrew Steinheiser Date: Mon, 7 Jul 2025 10:01:45 -0700 Subject: [PATCH 4/6] Add storage logic to responses (#800) * skeleton for addMessagesToConversation helper * more skeleton * better name * increment * add check for previousResponse and input array * store metadata on conversation and create array of final messages * add call to save messages * add logic for checking userId changed * add test for conversation user id logic * update logic for adding messages to conversation * remove unneeded error case * remove test case * dont filter, just map * create helper for convertInputToDBMessages * update store logic * save store data on conversation, check for previous message id and no store * add case to handle mismatched conversation storage settings * cleanup logic for checking if conversation is stored * basic spy * implement tests * update * add final spys * safeParse > safeParseAsync bc we don't need to await any refines, etc. * add comment * add userId and storeMessageContent fields to Conversation * adjust logic for response api to use new convo fields * fix bug in conversation service logic * update userId check * update tests * test name tweak * test naming * add jset mock cleanup * abstract testMessageContent helper * update test helper * test for function call and outputs message storage * update tests * cleanup test --- .../routes/responses/createResponse.test.ts | 258 +++++++++++++++++- .../src/routes/responses/createResponse.ts | 160 ++++++++++- .../src/conversations/ConversationsService.ts | 6 + .../MongoDbConversations.test.ts | 25 ++ .../src/conversations/MongoDbConversations.ts | 9 +- 5 files changed, 437 insertions(+), 21 deletions(-) diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts index 3f4ac49d5..6d00cfc21 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts @@ -1,6 +1,7 @@ import "dotenv/config"; import request from "supertest"; import type { Express } from "express"; +import type { Conversation, SomeMessage } from "mongodb-rag-core"; import { DEFAULT_API_PREFIX, type AppConfig } from "../../app"; import { makeTestApp } from "../../test/testHelpers"; import { basicResponsesRequestBody } from "../../test/testConfig"; @@ -9,12 +10,6 @@ import { ERR_MSG, type CreateResponseRequest } from "./createResponse"; jest.setTimeout(100000); -const badRequestError = (message: string) => ({ - type: ERROR_TYPE, - code: ERROR_CODE.INVALID_REQUEST_ERROR, - message, -}); - describe("POST /responses", () => { const endpointUrl = `${DEFAULT_API_PREFIX}/responses`; let app: Express; @@ -26,10 +21,15 @@ describe("POST 
/responses", () => { ({ app, ipAddress, origin, appConfig } = await makeTestApp()); }); + afterEach(() => { + jest.restoreAllMocks(); + }); + const makeCreateResponseRequest = ( body?: Partial, appOverride?: Express ) => { + // TODO: update this to use the openai client return request(appOverride ?? app) .post(endpointUrl) .set("X-Forwarded-For", ipAddress) @@ -246,6 +246,161 @@ describe("POST /responses", () => { expect(response.statusCode).toBe(200); }); + + it("Should store conversation messages if `storeMessageContent: undefined` and `store: true`", async () => { + const createSpy = jest.spyOn( + appConfig.conversationsRouterConfig.conversations, + "create" + ); + const addMessagesSpy = jest.spyOn( + appConfig.conversationsRouterConfig.conversations, + "addManyConversationMessages" + ); + + const storeMessageContent = undefined; + const conversation = + await appConfig.conversationsRouterConfig.conversations.create({ + storeMessageContent, + initialMessages: [{ role: "user", content: "What is MongoDB?" }], + }); + + const store = true; + const previousResponseId = conversation.messages[0].id.toString(); + const response = await makeCreateResponseRequest({ + previous_response_id: previousResponseId, + store, + }); + + const createdConversation = await createSpy.mock.results[0].value; + const addedMessages = await addMessagesSpy.mock.results[0].value; + + expect(response.statusCode).toBe(200); + expect(createdConversation.storeMessageContent).toEqual( + storeMessageContent + ); + testDefaultMessageContent({ + createdConversation, + addedMessages, + store, + }); + }); + + it("Should store conversation messages when `store: true`", async () => { + const createSpy = jest.spyOn( + appConfig.conversationsRouterConfig.conversations, + "create" + ); + const addMessagesSpy = jest.spyOn( + appConfig.conversationsRouterConfig.conversations, + "addManyConversationMessages" + ); + + const store = true; + const userId = "customUserId"; + const metadata = { + customMessage1: "customMessage1", + customMessage2: "customMessage2", + }; + const response = await makeCreateResponseRequest({ + store, + metadata, + user: userId, + }); + + const createdConversation = await createSpy.mock.results[0].value; + const addedMessages = await addMessagesSpy.mock.results[0].value; + + expect(response.statusCode).toBe(200); + expect(createdConversation.storeMessageContent).toEqual(store); + testDefaultMessageContent({ + createdConversation, + addedMessages, + userId, + store, + metadata, + }); + }); + + it("Should not store conversation messages when `store: false`", async () => { + const createSpy = jest.spyOn( + appConfig.conversationsRouterConfig.conversations, + "create" + ); + const addMessagesSpy = jest.spyOn( + appConfig.conversationsRouterConfig.conversations, + "addManyConversationMessages" + ); + + const store = false; + const userId = "customUserId"; + const metadata = { + customMessage1: "customMessage1", + customMessage2: "customMessage2", + }; + const response = await makeCreateResponseRequest({ + store, + metadata, + user: userId, + }); + + const createdConversation = await createSpy.mock.results[0].value; + const addedMessages = await addMessagesSpy.mock.results[0].value; + + expect(response.statusCode).toBe(200); + expect(createdConversation.storeMessageContent).toEqual(store); + testDefaultMessageContent({ + createdConversation, + addedMessages, + userId, + store, + metadata, + }); + }); + + it("Should store function_call messages when `store: true`", async () => { + const createSpy = jest.spyOn( + 
appConfig.conversationsRouterConfig.conversations, + "create" + ); + const addMessagesSpy = jest.spyOn( + appConfig.conversationsRouterConfig.conversations, + "addManyConversationMessages" + ); + + const store = true; + const functionCallType = "function_call"; + const functionCallOutputType = "function_call_output"; + const response = await makeCreateResponseRequest({ + store, + input: [ + { + type: functionCallType, + id: "call123", + name: "my_function", + arguments: `{"query": "value"}`, + status: "in_progress", + }, + { + type: functionCallOutputType, + call_id: "call123", + output: `{"result": "success"}`, + status: "completed", + }, + ], + }); + + const createdConversation = await createSpy.mock.results[0].value; + const addedMessages = await addMessagesSpy.mock.results[0].value; + + expect(response.statusCode).toBe(200); + expect(createdConversation.storeMessageContent).toEqual(store); + + expect(addedMessages[0].role).toEqual("system"); + expect(addedMessages[1].role).toEqual("system"); + + expect(addedMessages[0].content).toEqual(functionCallType); + expect(addedMessages[1].content).toEqual(functionCallOutputType); + }); }); describe("Invalid requests", () => { @@ -491,4 +646,95 @@ describe("POST /responses", () => { ); }); }); + + it("Should return 400 if user id has changed since the conversation was created", async () => { + const userId1 = "user1"; + const userId2 = "user2"; + const conversation = + await appConfig.conversationsRouterConfig.conversations.create({ + userId: userId1, + initialMessages: [{ role: "user", content: "What is MongoDB?" }], + }); + + const previousResponseId = conversation.messages[0].id.toString(); + const response = await makeCreateResponseRequest({ + previous_response_id: previousResponseId, + user: userId2, + }); + + expect(response.statusCode).toBe(400); + expect(response.body.error).toEqual( + badRequestError(ERR_MSG.CONVERSATION_USER_ID_CHANGED) + ); + }); + + it("Should return 400 if `store: false` and `previous_response_id` is provided", async () => { + const response = await makeCreateResponseRequest({ + previous_response_id: "123456789012123456789012", + store: false, + }); + + expect(response.statusCode).toBe(400); + expect(response.body.error).toEqual( + badRequestError(ERR_MSG.STORE_NOT_SUPPORTED) + ); + }); + + it("Should return 400 if `store: true` and `storeMessageContent: false`", async () => { + const conversation = + await appConfig.conversationsRouterConfig.conversations.create({ + storeMessageContent: false, + initialMessages: [{ role: "user", content: "" }], + }); + + const previousResponseId = conversation.messages[0].id.toString(); + const response = await makeCreateResponseRequest({ + previous_response_id: previousResponseId, + store: true, + }); + + expect(response.statusCode).toBe(400); + expect(response.body.error).toEqual( + badRequestError(ERR_MSG.CONVERSATION_STORE_MISMATCH) + ); + }); +}); + +// --- HELPERS --- + +const badRequestError = (message: string) => ({ + type: ERROR_TYPE, + code: ERROR_CODE.INVALID_REQUEST_ERROR, + message, }); + +interface TestDefaultMessageContentParams { + createdConversation: Conversation; + addedMessages: SomeMessage[]; + store: boolean; + userId?: string; + metadata?: Record; +} + +const testDefaultMessageContent = ({ + createdConversation, + addedMessages, + store, + userId, + metadata, +}: TestDefaultMessageContentParams) => { + expect(createdConversation.userId).toEqual(userId); + + expect(addedMessages[0].role).toBe("user"); + expect(addedMessages[1].role).toEqual("user"); + 
expect(addedMessages[2].role).toEqual("assistant"); + + expect(addedMessages[0].content).toBe(store ? "What is MongoDB?" : ""); + expect(addedMessages[1].content).toBeFalsy(); + expect(addedMessages[2].content).toEqual(store ? "some content" : ""); + + expect(addedMessages[0].metadata).toEqual(metadata); + expect(addedMessages[1].metadata).toEqual(metadata); + expect(addedMessages[2].metadata).toEqual(metadata); + if (metadata) expect(createdConversation.customData).toEqual({ metadata }); +}; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts index 3ab19ba47..aa0fec8c5 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts @@ -5,7 +5,11 @@ import type { } from "express"; import { ObjectId } from "mongodb"; import type { APIError } from "mongodb-rag-core/openai"; -import type { ConversationsService, Conversation } from "mongodb-rag-core"; +import type { + ConversationsService, + Conversation, + SomeMessage, +} from "mongodb-rag-core"; import { SomeExpressRequest } from "../../middleware"; import { getRequestId } from "../../utils"; import type { GenerateResponse } from "../../processors"; @@ -21,6 +25,8 @@ export const ERR_MSG = { INPUT_STRING: "Input must be a non-empty string", INPUT_ARRAY: "Input must be a string or array of messages. See https://platform.openai.com/docs/api-reference/responses/create#responses-create-input for more information.", + CONVERSATION_USER_ID_CHANGED: + "Path: body.user - User ID has changed since the conversation was created.", METADATA_LENGTH: "Too many metadata fields. Max 16.", TEMPERATURE: "Temperature must be 0 or unset", STREAM: "'stream' must be true", @@ -36,6 +42,10 @@ export const ERR_MSG = { `Path: body.model - ${model} is not supported.`, MAX_OUTPUT_TOKENS: (input: number, max: number) => `Path: body.max_output_tokens - ${input} is greater than the maximum allowed ${max}.`, + STORE_NOT_SUPPORTED: + "Path: body.previous_response_id | body.store - to use previous_response_id the store flag must be true", + CONVERSATION_STORE_MISMATCH: + "Path: body.previous_response_id | body.store - the conversation store flag does not match the store flag provided", }; const CreateResponseRequestBodySchema = z.object({ @@ -177,9 +187,7 @@ export function makeCreateResponseRoute({ try { // --- INPUT VALIDATION --- - const { error, data } = await CreateResponseRequestSchema.safeParseAsync( - req - ); + const { error, data } = CreateResponseRequestSchema.safeParse(req); if (error) { throw makeBadRequestError({ error: new Error(generateZodErrorMessage(error)), @@ -188,7 +196,15 @@ export function makeCreateResponseRoute({ } const { - body: { model, max_output_tokens, previous_response_id }, + body: { + model, + max_output_tokens, + previous_response_id, + store, + metadata, + user, + input, + }, } = data; // --- MODEL CHECK --- @@ -209,13 +225,32 @@ export function makeCreateResponseRoute({ }); } + // --- STORE CHECK --- + if (previous_response_id && !store) { + throw makeBadRequestError({ + error: new Error(ERR_MSG.STORE_NOT_SUPPORTED), + headers, + }); + } + // --- LOAD CONVERSATION --- const conversation = await loadConversationByMessageId({ messageId: previous_response_id, conversations, headers, + metadata, + userId: user, + storeMessageContent: store, }); + // --- CONVERSATION USER ID CHECK --- + if (hasConversationUserIdChanged(conversation, user)) { + 
throw makeBadRequestError({ + error: new Error(ERR_MSG.CONVERSATION_USER_ID_CHANGED), + headers, + }); + } + // --- MAX CONVERSATION LENGTH CHECK --- if ( hasTooManyUserMessagesInConversation( @@ -232,7 +267,17 @@ export function makeCreateResponseRoute({ } // TODO: actually implement this call - await generateResponse({} as any); + const { messages } = await generateResponse({} as any); + + // --- STORE MESSAGES IN CONVERSATION --- + await saveMessagesToConversation({ + conversations, + conversation, + store, + metadata, + input, + messages, + }); return res.status(200).send({ status: "ok" }); } catch (error) { @@ -254,15 +299,25 @@ interface LoadConversationByMessageIdParams { messageId?: string; conversations: ConversationsService; headers: Record; + metadata?: Record; + userId?: string; + storeMessageContent: boolean; } -async function loadConversationByMessageId({ +const loadConversationByMessageId = async ({ messageId, conversations, headers, -}: LoadConversationByMessageIdParams): Promise { + metadata, + userId, + storeMessageContent, +}: LoadConversationByMessageIdParams): Promise => { if (!messageId) { - return await conversations.create(); + return await conversations.create({ + userId, + storeMessageContent, + customData: { metadata }, + }); } const conversation = await conversations.findByMessageId({ @@ -276,6 +331,16 @@ async function loadConversationByMessageId({ }); } + // The default should be true because, if unset, we assume message data is stored + const shouldStoreConversation = conversation.storeMessageContent ?? true; + // this ensures that conversations will respect the store flag initially set + if (shouldStoreConversation !== storeMessageContent) { + throw makeBadRequestError({ + error: new Error(ERR_MSG.CONVERSATION_STORE_MISMATCH), + headers, + }); + } + const latestMessage = conversation.messages[conversation.messages.length - 1]; if (latestMessage.id.toString() !== messageId) { throw makeBadRequestError({ @@ -285,17 +350,17 @@ async function loadConversationByMessageId({ } return conversation; -} +}; const convertToObjectId = ( - messageId: string, + inputString: string, headers: Record ): ObjectId => { try { - return new ObjectId(messageId); + return new ObjectId(inputString); } catch (error) { throw makeBadRequestError({ - error: new Error(ERR_MSG.INVALID_OBJECT_ID(messageId)), + error: new Error(ERR_MSG.INVALID_OBJECT_ID(inputString)), headers, }); } @@ -305,10 +370,77 @@ const convertToObjectId = ( export const hasTooManyUserMessagesInConversation = ( conversation: Conversation, maxUserMessagesInConversation: number -) => { +): boolean => { const numUserMessages = conversation.messages.reduce( (acc, message) => (message.role === "user" ? 
acc + 1 : acc), 0 ); return numUserMessages >= maxUserMessagesInConversation; }; + +const hasConversationUserIdChanged = ( + conversation: Conversation, + userId?: string +): boolean => { + return conversation.userId !== userId; +}; + +interface AddMessagesToConversationParams { + conversations: ConversationsService; + conversation: Conversation; + store: boolean; + metadata?: Record; + input: CreateResponseRequest["body"]["input"]; + messages: Array; +} + +const saveMessagesToConversation = async ({ + conversations, + conversation, + store, + metadata, + input, + messages, +}: AddMessagesToConversationParams) => { + const messagesToAdd = [ + ...convertInputToDBMessages(input, store, metadata), + ...messages.map((message) => formatMessage(message, store, metadata)), + ]; + + await conversations.addManyConversationMessages({ + conversationId: conversation._id, + messages: messagesToAdd, + }); +}; + +const convertInputToDBMessages = ( + input: CreateResponseRequest["body"]["input"], + store: boolean, + metadata?: Record +): Array => { + if (typeof input === "string") { + return [formatMessage({ role: "user", content: input }, store, metadata)]; + } + + return input.map((message) => { + // handle function tool calls and outputs + const role = message.type === "message" ? message.role : "system"; + const content = + message.type === "message" ? message.content : message.type ?? ""; + + return formatMessage({ role, content }, store, metadata); + }); +}; + +const formatMessage = ( + message: SomeMessage, + store: boolean, + metadata?: Record +): SomeMessage => { + return { + ...message, + // store a placeholder string if we're not storing message data + content: store ? message.content : "", + metadata, + }; +}; diff --git a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts index a666ed132..2288486ad 100644 --- a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts +++ b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts @@ -162,6 +162,10 @@ export interface Conversation< createdAt: Date; /** The hostname that the request originated from. */ requestOrigin?: string; + /** The user id that the request originated from. */ + userId?: string; + /** Whether to store the message's content data. */ + storeMessageContent?: boolean; /** Custom data to include in the Conversation persisted to the database. 
@@ -172,6 +176,8 @@ export interface Conversation< export type CreateConversationParams = { initialMessages?: SomeMessage[]; customData?: ConversationCustomData; + userId?: string; + storeMessageContent?: boolean; }; export type AddMessageParams = Omit & { diff --git a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts index 8bb20e7ef..5a7ca43b6 100644 --- a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts +++ b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts @@ -55,6 +55,31 @@ describe("Conversations Service", () => { .findOne({ _id: conversation._id }); expect(conversationInDb).toStrictEqual(conversation); }); + test("Should create a conversation with userId", async () => { + const userId = "123"; + const conversation = await conversationsService.create({ + userId, + }); + const conversationInDb = await mongodb + .collection("conversations") + .findOne({ _id: conversation._id }); + + expect(conversationInDb).toHaveProperty("userId", userId); + }); + test("Should create a conversation with storeMessageContent", async () => { + const storeMessageContent = true; + const conversation = await conversationsService.create({ + storeMessageContent, + }); + const conversationInDb = await mongodb + .collection("conversations") + .findOne({ _id: conversation._id }); + + expect(conversationInDb).toHaveProperty( + "storeMessageContent", + storeMessageContent + ); + }); test("Should add a message to a conversation", async () => { const conversation = await conversationsService.create({ initialMessages: [systemPrompt], diff --git a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts index 4210fe144..9cf1e796d 100644 --- a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts +++ b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts @@ -33,7 +33,7 @@ export function makeMongoDbConversationsService( async create(params) { const customData = params?.customData; const initialMessages = params?.initialMessages; - const newConversation = { + const newConversation: Conversation = { _id: new ObjectId(), messages: initialMessages ? initialMessages?.map(createMessageFromOpenAIChatMessage) @@ -44,6 +44,13 @@ export function makeMongoDbConversationsService( // which we don't want. 
...(customData !== undefined && { customData }),
     };
+    if (params?.userId !== undefined) {
+      newConversation.userId = params.userId;
+    }
+    if (params?.storeMessageContent !== undefined) {
+      newConversation.storeMessageContent = params.storeMessageContent;
+    }
+
     const insertResult = await conversationsCollection.insertOne(
       newConversation
     );

From b4abf6b490a8f4f179f1da1be061bdc717457cca Mon Sep 17 00:00:00 2001
From: Andrew Steinheiser
Date: Wed, 16 Jul 2025 16:22:48 -0700
Subject: [PATCH 5/6] handle data streaming for new responses API (#807)

* ensure stream is configurable in chatbot-server-public layer
* setup data streamer helper
* add skeleton data streaming to createResponse service
* move StreamFunction type to chat-server, share with public
* fix test
* start streaming in createResponse
* add stream disconnect
* move in progress stream message
* stream event from verifiedAnswer helper
* create helper for addMessageToConversationVerifiedAnswerStream
* apply stream configs to chat-public
* update test to use openAI client
* update input schema for openai client call to responses
* add test helper for making local server
* almost finish converting createResponse tests to use openai client
* more tests
* disconnect data streamed on error
* update data stream logic for create response
* update data streamer sendResponsesEvent write type
* I think this still needs to be a string
* mapper for streamError
* export openai shim types
* mostly working tests with reading the response stream from openai client
* create test helper
* fix test helper for reading entire stream
* don't send normal http message at end (maybe need this when we support non-streaming version)
* improved tests
* more test improvement -- proper use of conversation service, additional conversation testing
* fix test for too many messages
* remove skip tests
* mostly working responses tests
* abstract helpers for openai client requests
* use helpers in create response tests
* fix tests by passing responseId
* skip problematic test
* skip problematic test
* create baseResponseData helper
* pass zod validated req body
* add tests for all responses fields
* remove log
* abstract helper for formatOpenaiError
* replace helper
* await server closing properly
* basic working responses tests with openai client
* update rate limit test
* fix testing port
* update test type related to responses streaming
* apply type to data streamer
* cleanup shared type
* fix router tests
* fix router tests
* update errors to be proper openai stream errors
* ensure format message cleans customData as well
* add comment
* update tests per review
* update test utils
* fix test type
* update openai rag-core to 5.9
* fix data streamer for responses events to be SSE compliant
* cleanup responses tests
* cleanup createResponse tests
* cleanup error handling to match openai spec
* fix tests for standard openai exceptions
* cleanup
* add "required" as an option for tool_choice
* cleanup datastreamer test globals
* add test to dataStreamer for streamResponses
---
 .../src/config.ts                             |   9 +-
 .../generateResponseWithSearchTool.test.ts    |   5 +-
 .../generateResponseWithSearchTool.ts         |  94 +-
 ...makeVerifiedAnswerGenerateResponse.test.ts |  32 +-
 .../makeVerifiedAnswerGenerateResponse.ts     |  55 +-
 .../routes/responses/createResponse.test.ts   | 941 ++++++++++--------
 .../src/routes/responses/createResponse.ts    | 161 ++-
 .../src/routes/responses/errors.ts            |  17 +-
 .../routes/responses/responsesRouter.test.ts  | 262 +++--
 .../src/test/testConfig.ts                    |   1 - 
.../src/test/testHelpers.ts | 64 +- packages/mongodb-rag-core/package.json | 2 +- .../mongodb-rag-core/src/DataStreamer.test.ts | 43 +- packages/mongodb-rag-core/src/DataStreamer.ts | 29 +- 14 files changed, 1129 insertions(+), 586 deletions(-) diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index 7f9db1944..0b0b67808 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -19,6 +19,7 @@ import { defaultCreateConversationCustomData, defaultAddMessageToConversationCustomData, makeVerifiedAnswerGenerateResponse, + addMessageToConversationVerifiedAnswerStream, } from "mongodb-chatbot-server"; import cookieParser from "cookie-parser"; import { blockGetRequests } from "./middleware/blockGetRequests"; @@ -40,7 +41,6 @@ import { import { AzureOpenAI } from "mongodb-rag-core/openai"; import { MongoClient } from "mongodb-rag-core/mongodb"; import { - ANALYZER_ENV_VARS, AZURE_OPENAI_ENV_VARS, PREPROCESSOR_ENV_VARS, TRACING_ENV_VARS, @@ -53,7 +53,10 @@ import { import { useSegmentIds } from "./middleware/useSegmentIds"; import { makeSearchTool } from "./tools/search"; import { makeMongoDbInputGuardrail } from "./processors/mongoDbInputGuardrail"; -import { makeGenerateResponseWithSearchTool } from "./processors/generateResponseWithSearchTool"; +import { + addMessageToConversationStream, + makeGenerateResponseWithSearchTool, +} from "./processors/generateResponseWithSearchTool"; import { makeBraintrustLogger } from "mongodb-rag-core/braintrust"; import { makeMongoDbScrubbedMessageStore } from "./tracing/scrubbedMessages/MongoDbScrubbedMessageStore"; import { MessageAnalysis } from "./tracing/scrubbedMessages/analyzeMessage"; @@ -231,6 +234,7 @@ export const generateResponse = wrapTraced( references: verifiedAnswer.references.map(addReferenceSourceType), }; }, + stream: addMessageToConversationVerifiedAnswerStream, onNoVerifiedAnswerFound: wrapTraced( makeGenerateResponseWithSearchTool({ languageModel, @@ -253,6 +257,7 @@ export const generateResponse = wrapTraced( searchTool: makeSearchTool(findContent), toolChoice: "auto", maxSteps: 5, + stream: addMessageToConversationStream, }), { name: "generateResponseWithSearchTool" } ), diff --git a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts index 3951d8141..723998986 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts @@ -351,18 +351,21 @@ describe("generateResponseWithSearchTool", () => { describe("streaming mode", () => { // Create a mock DataStreamer implementation const makeMockDataStreamer = () => { - const mockStreamData = jest.fn(); const mockConnect = jest.fn(); const mockDisconnect = jest.fn(); + const mockStreamData = jest.fn(); + const mockStreamResponses = jest.fn(); const mockStream = jest.fn().mockImplementation(async () => { // Process the stream and return a string result return "Hello"; }); + const dataStreamer = { connected: false, connect: mockConnect, disconnect: mockDisconnect, streamData: mockStreamData, + streamResponses: mockStreamResponses, stream: mockStream, } as DataStreamer; diff --git a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts 
b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts index 074184d5d..692ba6e37 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts @@ -6,7 +6,6 @@ import { AssistantMessage, ToolMessage, } from "mongodb-rag-core"; - import { CoreAssistantMessage, CoreMessage, @@ -28,6 +27,7 @@ import { GenerateResponse, GenerateResponseReturnValue, InputGuardrailResult, + type StreamFunction, } from "mongodb-chatbot-server"; import { MongoDbSearchToolArgs, @@ -52,8 +52,59 @@ export interface GenerateResponseWithSearchToolParams { search_content: SearchTool; }>; searchTool: SearchTool; + stream?: { + onLlmNotWorking: StreamFunction<{ notWorkingMessage: string }>; + onLlmRefusal: StreamFunction<{ refusalMessage: string }>; + onReferenceLinks: StreamFunction<{ references: References }>; + onTextDelta: StreamFunction<{ delta: string }>; + }; } +export const addMessageToConversationStream: GenerateResponseWithSearchToolParams["stream"] = + { + onLlmNotWorking({ dataStreamer, notWorkingMessage }) { + dataStreamer?.streamData({ + type: "delta", + data: notWorkingMessage, + }); + }, + onLlmRefusal({ dataStreamer, refusalMessage }) { + dataStreamer?.streamData({ + type: "delta", + data: refusalMessage, + }); + }, + onReferenceLinks({ dataStreamer, references }) { + dataStreamer?.streamData({ + type: "references", + data: references, + }); + }, + onTextDelta({ dataStreamer, delta }) { + dataStreamer?.streamData({ + type: "delta", + data: delta, + }); + }, + }; + +// TODO: implement this +export const responsesApiStream: GenerateResponseWithSearchToolParams["stream"] = + { + onLlmNotWorking() { + throw new Error("not yet implemented"); + }, + onLlmRefusal() { + throw new Error("not yet implemented"); + }, + onReferenceLinks() { + throw new Error("not yet implemented"); + }, + onTextDelta() { + throw new Error("not yet implemented"); + }, + }; + /** Generate chatbot response using RAG and a search tool named {@link SEARCH_TOOL_NAME}. 
*/ @@ -69,6 +120,7 @@ export function makeGenerateResponseWithSearchTool({ maxSteps = 2, searchTool, toolChoice, + stream, }: GenerateResponseWithSearchToolParams): GenerateResponse { return async function generateResponseWithSearchTool({ conversation, @@ -80,9 +132,11 @@ export function makeGenerateResponseWithSearchTool({ dataStreamer, request, }) { - if (shouldStream) { - assert(dataStreamer, "dataStreamer is required for streaming"); - } + const streamingModeActive = + shouldStream === true && + dataStreamer !== undefined && + stream !== undefined; + const userMessage: UserMessage = { role: "user", content: latestMessageText, @@ -179,10 +233,10 @@ export function makeGenerateResponseWithSearchTool({ switch (chunk.type) { case "text-delta": - if (shouldStream) { - dataStreamer?.streamData({ - data: chunk.textDelta, - type: "delta", + if (streamingModeActive) { + stream.onTextDelta({ + dataStreamer, + delta: chunk.textDelta, }); } break; @@ -202,10 +256,10 @@ export function makeGenerateResponseWithSearchTool({ // Stream references if we have any and weren't aborted if (references.length > 0 && !generationController.signal.aborted) { - if (shouldStream) { - dataStreamer?.streamData({ - data: references, - type: "references", + if (streamingModeActive) { + stream.onReferenceLinks({ + dataStreamer, + references, }); } } @@ -238,10 +292,10 @@ export function makeGenerateResponseWithSearchTool({ ...userMessageCustomData, ...guardrailResult, }; - if (shouldStream) { - dataStreamer?.streamData({ - type: "delta", - data: llmRefusalMessage, + if (streamingModeActive) { + stream.onLlmRefusal({ + dataStreamer, + refusalMessage: llmRefusalMessage, }); } return handleReturnGeneration({ @@ -293,10 +347,10 @@ export function makeGenerateResponseWithSearchTool({ }); } } catch (error: unknown) { - if (shouldStream) { - dataStreamer?.streamData({ - type: "delta", - data: llmNotWorkingMessage, + if (streamingModeActive) { + stream.onLlmNotWorking({ + dataStreamer, + notWorkingMessage: llmNotWorkingMessage, }); } diff --git a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts index c5618c9d2..90d005c1f 100644 --- a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts @@ -1,5 +1,8 @@ import { ObjectId } from "mongodb-rag-core/mongodb"; -import { makeVerifiedAnswerGenerateResponse } from "./makeVerifiedAnswerGenerateResponse"; +import { + makeVerifiedAnswerGenerateResponse, + type StreamFunction, +} from "./makeVerifiedAnswerGenerateResponse"; import { VerifiedAnswer, WithScore, DataStreamer } from "mongodb-rag-core"; import { GenerateResponseReturnValue } from "./GenerateResponse"; @@ -24,6 +27,29 @@ describe("makeVerifiedAnswerGenerateResponse", () => { }, ] satisfies GenerateResponseReturnValue["messages"]; + const streamVerifiedAnswer: StreamFunction<{ + verifiedAnswer: VerifiedAnswer; + }> = async ({ dataStreamer, verifiedAnswer }) => { + dataStreamer.streamData({ + type: "metadata", + data: { + verifiedAnswer: { + _id: verifiedAnswer._id, + created: verifiedAnswer.created, + updated: verifiedAnswer.updated, + }, + }, + }); + dataStreamer.streamData({ + type: "delta", + data: verifiedAnswer.answer, + }); + dataStreamer.streamData({ + type: "references", + data: verifiedAnswer.references, + }); + }; + // Create a mock verified answer const 
createMockVerifiedAnswer = (): WithScore => ({ answer: verifiedAnswerContent, @@ -55,6 +81,7 @@ describe("makeVerifiedAnswerGenerateResponse", () => { connect: jest.fn(), disconnect: jest.fn(), stream: jest.fn(), + streamResponses: jest.fn(), }); // Create base request parameters @@ -79,6 +106,9 @@ describe("makeVerifiedAnswerGenerateResponse", () => { onNoVerifiedAnswerFound: async () => ({ messages: noVerifiedAnswerFoundMessages, }), + stream: { + onVerifiedAnswerFound: streamVerifiedAnswer, + }, }); it("uses onNoVerifiedAnswerFound if no verified answer is found", async () => { diff --git a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts index 01d3be4f6..d8df30147 100644 --- a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts +++ b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts @@ -1,4 +1,8 @@ -import { VerifiedAnswer, FindVerifiedAnswerFunc } from "mongodb-rag-core"; +import { + VerifiedAnswer, + FindVerifiedAnswerFunc, + DataStreamer, +} from "mongodb-rag-core"; import { strict as assert } from "assert"; import { GenerateResponse, @@ -17,8 +21,40 @@ export interface MakeVerifiedAnswerGenerateResponseParams { onVerifiedAnswerFound?: (verifiedAnswer: VerifiedAnswer) => VerifiedAnswer; onNoVerifiedAnswerFound: GenerateResponse; + + stream?: { + onVerifiedAnswerFound: StreamFunction<{ verifiedAnswer: VerifiedAnswer }>; + }; } +export type StreamFunction = ( + params: { dataStreamer: DataStreamer } & Params +) => void; + +export const addMessageToConversationVerifiedAnswerStream: MakeVerifiedAnswerGenerateResponseParams["stream"] = + { + onVerifiedAnswerFound: ({ verifiedAnswer, dataStreamer }) => { + dataStreamer.streamData({ + type: "metadata", + data: { + verifiedAnswer: { + _id: verifiedAnswer._id, + created: verifiedAnswer.created, + updated: verifiedAnswer.updated, + }, + }, + }); + dataStreamer.streamData({ + type: "delta", + data: verifiedAnswer.answer, + }); + dataStreamer.streamData({ + type: "references", + data: verifiedAnswer.references, + }); + }, + }; + /** Searches for verified answers for the user query. 
If no verified answer can be found for the given query, the @@ -28,6 +64,7 @@ export const makeVerifiedAnswerGenerateResponse = ({ findVerifiedAnswer, onVerifiedAnswerFound, onNoVerifiedAnswerFound, + stream, }: MakeVerifiedAnswerGenerateResponseParams): GenerateResponse => { return async (args) => { const { latestMessageText, shouldStream, dataStreamer } = args; @@ -54,17 +91,11 @@ export const makeVerifiedAnswerGenerateResponse = ({ if (shouldStream) { assert(dataStreamer, "Must have dataStreamer if shouldStream=true"); - dataStreamer.streamData({ - type: "metadata", - data: metadata, - }); - dataStreamer.streamData({ - type: "delta", - data: answer, - }); - dataStreamer.streamData({ - type: "references", - data: references, + assert(stream, "Must have stream if shouldStream=true"); + + stream.onVerifiedAnswerFound({ + dataStreamer, + verifiedAnswer, }); } diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts index 6d00cfc21..4690a6225 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.test.ts @@ -1,154 +1,178 @@ import "dotenv/config"; -import request from "supertest"; -import type { Express } from "express"; -import type { Conversation, SomeMessage } from "mongodb-rag-core"; -import { DEFAULT_API_PREFIX, type AppConfig } from "../../app"; -import { makeTestApp } from "../../test/testHelpers"; -import { basicResponsesRequestBody } from "../../test/testConfig"; -import { ERROR_TYPE, ERROR_CODE } from "./errors"; +import type { Server } from "http"; +import { ObjectId } from "mongodb"; +import type { + Conversation, + ConversationsService, + SomeMessage, +} from "mongodb-rag-core"; +import { type AppConfig } from "../../app"; +import { + makeTestLocalServer, + makeOpenAiClient, + makeCreateResponseRequestStream, + type Stream, +} from "../../test/testHelpers"; +import { makeDefaultConfig } from "../../test/testConfig"; import { ERR_MSG, type CreateResponseRequest } from "./createResponse"; +import { ERROR_CODE, ERROR_TYPE } from "./errors"; jest.setTimeout(100000); describe("POST /responses", () => { - const endpointUrl = `${DEFAULT_API_PREFIX}/responses`; - let app: Express; let appConfig: AppConfig; + let server: Server; let ipAddress: string; let origin: string; + let conversations: ConversationsService; beforeEach(async () => { - ({ app, ipAddress, origin, appConfig } = await makeTestApp()); + appConfig = await makeDefaultConfig(); + + ({ conversations } = appConfig.responsesRouterConfig.createResponse); + + // use a unique port so this doesn't collide with other test suites + const testPort = 5200; + ({ server, ipAddress, origin } = await makeTestLocalServer( + appConfig, + testPort + )); }); - afterEach(() => { + afterEach(async () => { + server?.listening && server?.close(); jest.restoreAllMocks(); }); - const makeCreateResponseRequest = ( - body?: Partial, - appOverride?: Express + const makeClientAndRequest = ( + body?: Partial ) => { - // TODO: update this to use the openai client - return request(appOverride ?? 
app) - .post(endpointUrl) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ ...basicResponsesRequestBody, ...body }); + const openAiClient = makeOpenAiClient(origin, ipAddress); + return makeCreateResponseRequestStream(openAiClient, body); }; describe("Valid requests", () => { - it("Should return 200 given a string input", async () => { - const response = await makeCreateResponseRequest(); + it("Should return responses given a string input", async () => { + const stream = await makeClientAndRequest(); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody: {}, stream }); }); - it("Should return 200 given a message array input", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a message array input", async () => { + const requestBody: Partial = { input: [ { role: "system", content: "You are a helpful assistant." }, { role: "user", content: "What is MongoDB?" }, { role: "assistant", content: "MongoDB is a document database." }, { role: "user", content: "What is a document database?" }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 given a valid request with instructions", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a valid request with instructions", async () => { + const requestBody: Partial = { instructions: "You are a helpful chatbot.", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with valid max_output_tokens", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with valid max_output_tokens", async () => { + const requestBody: Partial = { max_output_tokens: 4000, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with valid metadata", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with valid metadata", async () => { + const requestBody: Partial = { metadata: { key1: "value1", key2: "value2" }, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with valid temperature", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with valid temperature", async () => { + const requestBody: Partial = { temperature: 0, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with previous_response_id", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - initialMessages: [{ role: "user", content: "What is MongoDB?" }], - }); + it("Should return responses with previous_response_id", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" 
}, + ]; + const { messages } = await conversations.create({ initialMessages }); - const previousResponseId = conversation.messages[0].id; - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId.toString(), - }); + const previous_response_id = messages.at(-1)?.id.toString(); + const requestBody: Partial = { + previous_response_id, + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 if previous_response_id is the latest message", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - initialMessages: [ - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a document database." }, - { role: "user", content: "What is a document database?" }, - ], - }); + it("Should return responses if previous_response_id is the latest message", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + { role: "assistant", content: "Initial response!" }, + { role: "user", content: "Another message!" }, + ]; + const { messages } = await conversations.create({ initialMessages }); - const previousResponseId = conversation.messages[2].id; - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId.toString(), - }); + const previous_response_id = messages.at(-1)?.id.toString(); + const requestBody: Partial = { + previous_response_id, + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with user", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with user", async () => { + const requestBody: Partial = { user: "some-user-id", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with store=false", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with store=false", async () => { + const requestBody: Partial = { store: false, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with store=true", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with store=true", async () => { + const requestBody: Partial = { store: true, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with tools and tool_choice", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with tools and tool_choice", async () => { + const requestBody: Partial = { tools: [ { + type: "function", + strict: true, name: "test-tool", description: "A tool for testing.", parameters: { @@ -161,15 +185,18 @@ describe("POST /responses", () => { }, ], tool_choice: "auto", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with a specific function tool_choice", 
async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with a specific function tool_choice", async () => { + const requestBody: Partial = { tools: [ { + type: "function", + strict: true, name: "test-tool", description: "A tool for testing.", parameters: { @@ -185,30 +212,32 @@ describe("POST /responses", () => { type: "function", name: "test-tool", }, - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 given a message array with function_call", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a message array with function_call", async () => { + const requestBody: Partial = { input: [ { role: "user", content: "What is MongoDB?" }, { type: "function_call", - id: "call123", + call_id: "call123", name: "my_function", arguments: `{"query": "value"}`, status: "in_progress", }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 given a message array with function_call_output", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses given a message array with function_call_output", async () => { + const requestBody: Partial = { input: [ { role: "user", content: "What is MongoDB?" }, { @@ -218,103 +247,91 @@ describe("POST /responses", () => { status: "completed", }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with tool_choice 'none'", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with a valid tool_choice", async () => { + const requestBody: Partial = { tool_choice: "none", - }); - - expect(response.statusCode).toBe(200); - }); - - it("Should return 200 with tool_choice 'only'", async () => { - const response = await makeCreateResponseRequest({ - tool_choice: "only", - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); - it("Should return 200 with an empty tools array", async () => { - const response = await makeCreateResponseRequest({ + it("Should return responses with an empty tools array", async () => { + const requestBody: Partial = { tools: [], - }); + }; + const stream = await makeClientAndRequest(requestBody); - expect(response.statusCode).toBe(200); + await expectValidResponses({ requestBody, stream }); }); it("Should store conversation messages if `storeMessageContent: undefined` and `store: true`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const storeMessageContent = undefined; - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - storeMessageContent, - initialMessages: [{ role: "user", content: "What is MongoDB?" }], - }); + const initialMessages: Array = [ + { role: "user", content: "Initial message!" 
}, + ]; + const { _id, messages } = await conversations.create({ + storeMessageContent, + initialMessages, + }); const store = true; - const previousResponseId = conversation.messages[0].id.toString(); - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId, + const previous_response_id = messages.at(-1)?.id.toString(); + const requestBody: Partial = { + previous_response_id, store, - }); + }; + const stream = await makeClientAndRequest(requestBody); + + const updatedConversation = await conversations.findById({ _id }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + await expectValidResponses({ requestBody, stream }); - expect(response.statusCode).toBe(200); - expect(createdConversation.storeMessageContent).toEqual( + expect(updatedConversation?.storeMessageContent).toEqual( storeMessageContent ); - testDefaultMessageContent({ - createdConversation, - addedMessages, + expectDefaultMessageContent({ + initialMessages, + updatedConversation, store, }); }); it("Should store conversation messages when `store: true`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const store = true; const userId = "customUserId"; const metadata = { customMessage1: "customMessage1", customMessage2: "customMessage2", }; - const response = await makeCreateResponseRequest({ + const requestBody: Partial = { store, metadata, user: userId, - }); + }; + const stream = await makeClientAndRequest(requestBody); - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + const results = await expectValidResponses({ requestBody, stream }); + + const updatedConversation = await conversations.findByMessageId({ + messageId: getMessageIdFromResults(results), + }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } - expect(response.statusCode).toBe(200); - expect(createdConversation.storeMessageContent).toEqual(store); - testDefaultMessageContent({ - createdConversation, - addedMessages, + expect(updatedConversation.storeMessageContent).toEqual(store); + expectDefaultMessageContent({ + updatedConversation, userId, store, metadata, @@ -322,35 +339,31 @@ describe("POST /responses", () => { }); it("Should not store conversation messages when `store: false`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const store = false; const userId = "customUserId"; const metadata = { customMessage1: "customMessage1", customMessage2: "customMessage2", }; - const response = await makeCreateResponseRequest({ + const requestBody: Partial = { store, metadata, user: userId, - }); + }; + const stream = await makeClientAndRequest(requestBody); - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + const results = await expectValidResponses({ requestBody, stream }); - expect(response.statusCode).toBe(200); - 
expect(createdConversation.storeMessageContent).toEqual(store); - testDefaultMessageContent({ - createdConversation, - addedMessages, + const updatedConversation = await conversations.findByMessageId({ + messageId: getMessageIdFromResults(results), + }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } + + expect(updatedConversation.storeMessageContent).toEqual(store); + expectDefaultMessageContent({ + updatedConversation, userId, store, metadata, @@ -358,24 +371,15 @@ describe("POST /responses", () => { }); it("Should store function_call messages when `store: true`", async () => { - const createSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "create" - ); - const addMessagesSpy = jest.spyOn( - appConfig.conversationsRouterConfig.conversations, - "addManyConversationMessages" - ); - const store = true; const functionCallType = "function_call"; const functionCallOutputType = "function_call_output"; - const response = await makeCreateResponseRequest({ + const requestBody: Partial = { store, input: [ { type: functionCallType, - id: "call123", + call_id: "call123", name: "my_function", arguments: `{"query": "value"}`, status: "in_progress", @@ -387,139 +391,138 @@ describe("POST /responses", () => { status: "completed", }, ], - }); + }; + const stream = await makeClientAndRequest(requestBody); - const createdConversation = await createSpy.mock.results[0].value; - const addedMessages = await addMessagesSpy.mock.results[0].value; + const results = await expectValidResponses({ requestBody, stream }); - expect(response.statusCode).toBe(200); - expect(createdConversation.storeMessageContent).toEqual(store); + const updatedConversation = await conversations.findByMessageId({ + messageId: getMessageIdFromResults(results), + }); + if (!updatedConversation) { + return expect(updatedConversation).not.toBeNull(); + } + + expect(updatedConversation.storeMessageContent).toEqual(store); - expect(addedMessages[0].role).toEqual("system"); - expect(addedMessages[1].role).toEqual("system"); + expect(updatedConversation.messages[0].role).toEqual("system"); + expect(updatedConversation.messages[0].content).toEqual(functionCallType); - expect(addedMessages[0].content).toEqual(functionCallType); - expect(addedMessages[1].content).toEqual(functionCallOutputType); + expect(updatedConversation.messages[1].role).toEqual("system"); + expect(updatedConversation.messages[1].content).toEqual( + functionCallOutputType + ); }); }); describe("Invalid requests", () => { - it("Should return 400 with an empty input string", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if empty input string", async () => { + const stream = await makeClientAndRequest({ input: "", }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.input - ${ERR_MSG.INPUT_STRING}`) - ); + await expectInvalidResponses({ + stream, + message: `Path: body.input - ${ERR_MSG.INPUT_STRING}`, + }); }); - it("Should return 400 with an empty message array", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if empty message array", async () => { + const stream = await makeClientAndRequest({ input: [], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.input - ${ERR_MSG.INPUT_ARRAY}`) - ); - }); - - it("Should return 400 if model is not mongodb-chat-latest", async () => { - const 
response = await makeCreateResponseRequest({ - model: "gpt-4o-mini", + await expectInvalidResponses({ + stream, + message: `Path: body.input - ${ERR_MSG.INPUT_ARRAY}`, }); - - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.MODEL_NOT_SUPPORTED("gpt-4o-mini")) - ); }); - it("Should return 400 if stream is not true", async () => { - const response = await makeCreateResponseRequest({ - stream: false, + it("Should return error responses if model is not supported via config", async () => { + const invalidModel = "invalid-model"; + const stream = await makeClientAndRequest({ + model: invalidModel, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.stream - ${ERR_MSG.STREAM}`) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MODEL_NOT_SUPPORTED(invalidModel), + }); }); - it("Should return 400 if max_output_tokens is > allowed limit", async () => { + it("Should return error responses if max_output_tokens is > allowed limit", async () => { const max_output_tokens = 4001; - - const response = await makeCreateResponseRequest({ + const stream = await makeClientAndRequest({ max_output_tokens, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.MAX_OUTPUT_TOKENS(max_output_tokens, 4000)) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MAX_OUTPUT_TOKENS(max_output_tokens, 4000), + }); }); - it("Should return 400 if metadata has too many fields", async () => { + it("Should return error responses if metadata has too many fields", async () => { const metadata: Record = {}; for (let i = 0; i < 17; i++) { metadata[`key${i}`] = "value"; } - const response = await makeCreateResponseRequest({ + const stream = await makeClientAndRequest({ metadata, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.metadata - ${ERR_MSG.METADATA_LENGTH}`) - ); + await expectInvalidResponses({ + stream, + message: `Path: body.metadata - ${ERR_MSG.METADATA_LENGTH}`, + }); }); - it("Should return 400 if metadata value is too long", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if metadata value is too long", async () => { + const stream = await makeClientAndRequest({ metadata: { key1: "a".repeat(513) }, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - "Path: body.metadata.key1 - String must contain at most 512 character(s)" - ) - ); + await expectInvalidResponses({ + stream, + message: + "Path: body.metadata.key1 - String must contain at most 512 character(s)", + }); }); - it("Should return 400 if temperature is not 0", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if temperature is not 0", async () => { + const stream = await makeClientAndRequest({ temperature: 0.5 as any, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(`Path: body.temperature - ${ERR_MSG.TEMPERATURE}`) - ); + await expectInvalidResponses({ + stream, + message: `Path: body.temperature - ${ERR_MSG.TEMPERATURE}`, + }); }); - it("Should return 400 if messages contain an invalid role", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if messages contain an invalid role", async () => { + const stream = await 
makeClientAndRequest({ input: [ { role: "user", content: "What is MongoDB?" }, - { role: "invalid-role" as any, content: "This is an invalid role." }, + { + role: "invalid-role" as any, + content: "This is an invalid role.", + }, ], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.input - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.input - Invalid input", + }); }); - it("Should return 400 if function_call has an invalid status", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if function_call has an invalid status", async () => { + const stream = await makeClientAndRequest({ input: [ { type: "function_call", - id: "call123", + call_id: "call123", name: "my_function", arguments: `{"query": "value"}`, status: "invalid_status" as any, @@ -527,14 +530,14 @@ describe("POST /responses", () => { ], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.input - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.input - Invalid input", + }); }); - it("Should return 400 if function_call_output has an invalid status", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if function_call_output has an invalid status", async () => { + const stream = await makeClientAndRequest({ input: [ { type: "function_call_output", @@ -545,196 +548,328 @@ describe("POST /responses", () => { ], }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.input - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.input - Invalid input", + }); }); - it("Should return 400 with an invalid tool_choice string", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses with an invalid tool_choice string", async () => { + const stream = await makeClientAndRequest({ tool_choice: "invalid_choice" as any, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError("Path: body.tool_choice - Invalid input") - ); + await expectInvalidResponses({ + stream, + message: "Path: body.tool_choice - Invalid input", + }); }); - it("Should return 400 if max_output_tokens is negative", async () => { - const response = await makeCreateResponseRequest({ + it("Should return error responses if max_output_tokens is negative", async () => { + const stream = await makeClientAndRequest({ max_output_tokens: -1, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - "Path: body.max_output_tokens - Number must be greater than or equal to 0" - ) - ); + await expectInvalidResponses({ + stream, + message: + "Path: body.max_output_tokens - Number must be greater than or equal to 0", + }); }); - it("Should return 400 if previous_response_id is not a valid ObjectId", async () => { - const messageId = "some-id"; - - const response = await makeCreateResponseRequest({ - previous_response_id: messageId, + it("Should return error responses if previous_response_id is not a valid ObjectId", async () => { + const previous_response_id = "some-id"; + const stream = await makeClientAndRequest({ + previous_response_id, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - 
badRequestError(ERR_MSG.INVALID_OBJECT_ID(messageId)) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.INVALID_OBJECT_ID(previous_response_id), + }); }); - it("Should return 400 if previous_response_id is not found", async () => { - const messageId = "123456789012123456789012"; - - const response = await makeCreateResponseRequest({ - previous_response_id: messageId, + it("Should return error responses if previous_response_id is not found", async () => { + const previous_response_id = "123456789012123456789012"; + const stream = await makeClientAndRequest({ + previous_response_id, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.MESSAGE_NOT_FOUND(messageId)) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MESSAGE_NOT_FOUND(previous_response_id), + }); }); - it("Should return 400 if previous_response_id is not the latest message", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - initialMessages: [ - { role: "user", content: "What is MongoDB?" }, - { role: "assistant", content: "MongoDB is a document database." }, - { role: "user", content: "What is a document database?" }, - ], - }); + it("Should return error responses if previous_response_id is not the latest message", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + { role: "assistant", content: "Initial response!" }, + { role: "user", content: "Another message!" }, + ]; + const { messages } = await conversations.create({ initialMessages }); - const previousResponseId = conversation.messages[0].id; - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId.toString(), + const previous_response_id = messages[0].id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - ERR_MSG.MESSAGE_NOT_LATEST(previousResponseId.toString()) - ) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.MESSAGE_NOT_LATEST(previous_response_id), + }); }); - it("Should return 400 if there are too many messages in the conversation", async () => { - const maxUserMessagesInConversation = 0; - const newApp = await makeTestApp({ - responsesRouterConfig: { - ...appConfig.responsesRouterConfig, - createResponse: { - ...appConfig.responsesRouterConfig.createResponse, - maxUserMessagesInConversation, - }, - }, + it("Should return error responses if there are too many messages in the conversation", async () => { + const { maxUserMessagesInConversation } = + appConfig.responsesRouterConfig.createResponse; + + const initialMessages = Array(maxUserMessagesInConversation).fill({ + role: "user", + content: "Initial message!", }); + const { messages } = await conversations.create({ initialMessages }); - const response = await makeCreateResponseRequest({}, newApp.app); + const previous_response_id = messages.at(-1)?.id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, + }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError( - ERR_MSG.TOO_MANY_MESSAGES(maxUserMessagesInConversation) - ) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.TOO_MANY_MESSAGES(maxUserMessagesInConversation), + }); }); - }); - it("Should return 400 if user id has changed since the conversation was created", async () => { - const 
userId1 = "user1"; - const userId2 = "user2"; - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ - userId: userId1, - initialMessages: [{ role: "user", content: "What is MongoDB?" }], + it("Should return error responses if user id has changed since the conversation was created", async () => { + const userId = "user1"; + const badUserId = "user2"; + + const initialMessages: Array = [ + { role: "user", content: "Initial message!" }, + ]; + const { messages } = await conversations.create({ + userId, + initialMessages, + }); + + const previous_response_id = messages.at(-1)?.id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, + user: badUserId, }); - const previousResponseId = conversation.messages[0].id.toString(); - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId, - user: userId2, + await expectInvalidResponses({ + stream, + message: ERR_MSG.CONVERSATION_USER_ID_CHANGED, + }); }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.CONVERSATION_USER_ID_CHANGED) - ); - }); + it("Should return error responses if `store: false` and `previous_response_id` is provided", async () => { + const stream = await makeClientAndRequest({ + previous_response_id: "123456789012123456789012", + store: false, + }); - it("Should return 400 if `store: false` and `previous_response_id` is provided", async () => { - const response = await makeCreateResponseRequest({ - previous_response_id: "123456789012123456789012", - store: false, + await expectInvalidResponses({ + stream, + message: ERR_MSG.STORE_NOT_SUPPORTED, + }); }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.STORE_NOT_SUPPORTED) - ); - }); - - it("Should return 400 if `store: true` and `storeMessageContent: false`", async () => { - const conversation = - await appConfig.conversationsRouterConfig.conversations.create({ + it("Should return error responses if `store: true` and `storeMessageContent: false`", async () => { + const initialMessages: Array = [ + { role: "user", content: "Initial message!" 
}, + ]; + const { messages } = await conversations.create({ storeMessageContent: false, - initialMessages: [{ role: "user", content: "" }], + initialMessages, }); - const previousResponseId = conversation.messages[0].id.toString(); - const response = await makeCreateResponseRequest({ - previous_response_id: previousResponseId, - store: true, - }); + const previous_response_id = messages.at(-1)?.id.toString(); + const stream = await makeClientAndRequest({ + previous_response_id, + store: true, + }); - expect(response.statusCode).toBe(400); - expect(response.body.error).toEqual( - badRequestError(ERR_MSG.CONVERSATION_STORE_MISMATCH) - ); + await expectInvalidResponses({ + stream, + message: ERR_MSG.CONVERSATION_STORE_MISMATCH, + }); + }); }); }); // --- HELPERS --- -const badRequestError = (message: string) => ({ - type: ERROR_TYPE, - code: ERROR_CODE.INVALID_REQUEST_ERROR, +const getMessageIdFromResults = (results?: Array) => { + if (!results?.length) throw new Error("No results found"); + + const messageId = results.at(-1)?.response?.id; + + if (typeof messageId !== "string") throw new Error("Message ID not found"); + + return new ObjectId(messageId); +}; + +interface ExpectInvalidResponsesParams { + stream: Stream; + message: string; +} + +const expectInvalidResponses = async ({ + stream, message, -}); +}: ExpectInvalidResponsesParams) => { + const responses: any[] = []; + try { + for await (const event of stream) { + responses.push(event); + } + + fail("expected error"); + } catch (err: any) { + expect(err.type).toBe(ERROR_TYPE); + expect(err.code).toBe(ERROR_CODE.INVALID_REQUEST_ERROR); + expect(err.error.type).toBe(ERROR_TYPE); + expect(err.error.code).toBe(ERROR_CODE.INVALID_REQUEST_ERROR); + expect(err.error.message).toBe(message); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(0); +}; -interface TestDefaultMessageContentParams { - createdConversation: Conversation; - addedMessages: SomeMessage[]; +interface ExpectValidResponsesParams { + stream: Stream; + requestBody: Partial; +} + +const expectValidResponses = async ({ + stream, + requestBody, +}: ExpectValidResponsesParams) => { + const responses: any[] = []; + for await (const event of stream) { + responses.push(event); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(3); + + expect(responses[0].type).toBe("response.created"); + expect(responses[1].type).toBe("response.in_progress"); + expect(responses[2].type).toBe("response.completed"); + + responses.forEach(({ response, sequence_number }, index) => { + // basic response properties + expect(sequence_number).toBe(index); + expect(typeof response.id).toBe("string"); + expect(typeof response.created_at).toBe("number"); + expect(response.object).toBe("response"); + expect(response.error).toBeNull(); + expect(response.incomplete_details).toBeNull(); + expect(response.model).toBe("mongodb-chat-latest"); + expect(response.output_text).toBe(""); + expect(response.output).toEqual([]); + expect(response.parallel_tool_calls).toBe(true); + expect(response.temperature).toBe(0); + expect(response.stream).toBe(true); + expect(response.top_p).toBeNull(); + + // conditional upon request body properties + if (requestBody.instructions) { + expect(response.instructions).toBe(requestBody.instructions); + } else { + expect(response.instructions).toBeNull(); + } + if (requestBody.max_output_tokens) { + expect(response.max_output_tokens).toBe(requestBody.max_output_tokens); + } else { + 
expect(response.max_output_tokens).toBe(1000); + } + if (requestBody.previous_response_id) { + expect(response.previous_response_id).toBe( + requestBody.previous_response_id + ); + } else { + expect(response.previous_response_id).toBeNull(); + } + if (typeof requestBody.store === "boolean") { + expect(response.store).toBe(requestBody.store); + } else { + expect(response.store).toBe(true); + } + if (requestBody.tool_choice) { + expect(response.tool_choice).toEqual(requestBody.tool_choice); + } else { + expect(response.tool_choice).toBe("auto"); + } + if (requestBody.tools) { + expect(response.tools).toEqual(requestBody.tools); + } else { + expect(response.tools).toEqual([]); + } + if (requestBody.user) { + expect(response.user).toBe(requestBody.user); + } else { + expect(response.user).toBeUndefined(); + } + if (requestBody.metadata) { + expect(response.metadata).toEqual(requestBody.metadata); + } else { + expect(response.metadata).toBeNull(); + } + }); + + return responses; +}; + +interface ExpectDefaultMessageContentParams { + initialMessages?: Array; + updatedConversation: Conversation; store: boolean; userId?: string; - metadata?: Record; + metadata?: Record | null; } -const testDefaultMessageContent = ({ - createdConversation, - addedMessages, +const expectDefaultMessageContent = ({ + initialMessages, + updatedConversation, store, userId, - metadata, -}: TestDefaultMessageContentParams) => { - expect(createdConversation.userId).toEqual(userId); - - expect(addedMessages[0].role).toBe("user"); - expect(addedMessages[1].role).toEqual("user"); - expect(addedMessages[2].role).toEqual("assistant"); - - expect(addedMessages[0].content).toBe(store ? "What is MongoDB?" : ""); - expect(addedMessages[1].content).toBeFalsy(); - expect(addedMessages[2].content).toEqual(store ? "some content" : ""); - - expect(addedMessages[0].metadata).toEqual(metadata); - expect(addedMessages[1].metadata).toEqual(metadata); - expect(addedMessages[2].metadata).toEqual(metadata); - if (metadata) expect(createdConversation.customData).toEqual({ metadata }); + metadata = null, +}: ExpectDefaultMessageContentParams) => { + expect(updatedConversation.userId).toEqual(userId); + if (metadata) expect(updatedConversation.customData).toEqual({ metadata }); + + const defaultMessagesLength = 3; + const initialMessagesLength = initialMessages?.length ?? 0; + const totalMessagesLength = defaultMessagesLength + initialMessagesLength; + + const { messages } = updatedConversation; + expect(messages.length).toEqual(totalMessagesLength); + + initialMessages?.forEach((initialMessage, index) => { + expect(messages[index].role).toEqual(initialMessage.role); + expect(messages[index].content).toEqual(initialMessage.content); + expect(messages[index].metadata).toEqual(initialMessage.metadata); + expect(messages[index].customData).toEqual(initialMessage.customData); + }); + + const firstMessage = messages[initialMessagesLength]; + const secondMessage = messages[initialMessagesLength + 1]; + const thirdMessage = messages[initialMessagesLength + 2]; + + expect(firstMessage.role).toBe("user"); + expect(firstMessage.content).toBe(store ? "What is MongoDB?" : ""); + expect(firstMessage.metadata).toEqual(metadata); + + expect(secondMessage.role).toEqual("user"); + expect(secondMessage.content).toBeFalsy(); + expect(secondMessage.metadata).toEqual(metadata); + + expect(thirdMessage.role).toEqual("assistant"); + expect(thirdMessage.content).toEqual(store ? 
"some content" : ""); + expect(thirdMessage.metadata).toEqual(metadata); }; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts index aa0fec8c5..7d2654c43 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts @@ -4,11 +4,11 @@ import type { Response as ExpressResponse, } from "express"; import { ObjectId } from "mongodb"; -import type { APIError } from "mongodb-rag-core/openai"; -import type { - ConversationsService, - Conversation, - SomeMessage, +import type { OpenAI } from "mongodb-rag-core/openai"; +import { + type ConversationsService, + type Conversation, + makeDataStreamer, } from "mongodb-rag-core"; import { SomeExpressRequest } from "../../middleware"; import { getRequestId } from "../../utils"; @@ -19,8 +19,22 @@ import { generateZodErrorMessage, sendErrorResponse, ERROR_TYPE, + type SomeOpenAIAPIError, } from "./errors"; +type StreamCreatedMessage = Omit< + OpenAI.Responses.ResponseCreatedEvent, + "sequence_number" +>; +type StreamInProgressMessage = Omit< + OpenAI.Responses.ResponseInProgressEvent, + "sequence_number" +>; +type StreamCompletedMessage = Omit< + OpenAI.Responses.ResponseCompletedEvent, + "sequence_number" +>; + export const ERR_MSG = { INPUT_STRING: "Input must be a non-empty string", INPUT_ARRAY: @@ -64,10 +78,7 @@ const CreateResponseRequestBodySchema = z.object({ // function tool call z.object({ type: z.literal("function_call"), - id: z - .string() - .optional() - .describe("Unique ID of the function tool call"), + call_id: z.string().describe("Unique ID of the function tool call"), name: z.string().describe("Name of the function tool to call"), arguments: z .string() @@ -123,11 +134,11 @@ const CreateResponseRequestBodySchema = z.object({ .default(0), tool_choice: z .union([ - z.enum(["none", "only", "auto"]), + z.enum(["none", "auto", "required"]), z .object({ - name: z.string(), type: z.literal("function"), + name: z.string(), }) .describe("Function tool choice"), ]) @@ -137,6 +148,8 @@ const CreateResponseRequestBodySchema = z.object({ tools: z .array( z.object({ + type: z.literal("function"), + strict: z.boolean(), name: z.string(), description: z.string().optional(), parameters: z @@ -184,8 +197,11 @@ export function makeCreateResponseRoute({ ) => { const reqId = getRequestId(req); const headers = req.headers as Record; + const dataStreamer = makeDataStreamer(); try { + dataStreamer.connect(res); + // --- INPUT VALIDATION --- const { error, data } = CreateResponseRequestSchema.safeParse(req); if (error) { @@ -266,6 +282,31 @@ export function makeCreateResponseRoute({ }); } + // generate responseId to use in conversation DB AND Responses API stream + const responseId = new ObjectId(); + const baseResponse = makeBaseResponseData({ + responseId, + data: data.body, + }); + + const createdMessage: StreamCreatedMessage = { + type: "response.created", + response: { + ...baseResponse, + created_at: Date.now(), + }, + }; + dataStreamer.streamResponses(createdMessage); + + const inProgressMessage: StreamInProgressMessage = { + type: "response.in_progress", + response: { + ...baseResponse, + created_at: Date.now(), + }, + }; + dataStreamer.streamResponses(inProgressMessage); + // TODO: actually implement this call const { messages } = await generateResponse({} as any); @@ -277,20 +318,39 @@ export function makeCreateResponseRoute({ metadata, input, 
messages, + responseId, }); - return res.status(200).send({ status: "ok" }); + const completedMessage: StreamCompletedMessage = { + type: "response.completed", + response: { + ...baseResponse, + created_at: Date.now(), + }, + }; + dataStreamer.streamResponses(completedMessage); } catch (error) { const standardError = - (error as APIError)?.type === ERROR_TYPE - ? (error as APIError) + (error as SomeOpenAIAPIError)?.type === ERROR_TYPE + ? (error as SomeOpenAIAPIError) : makeInternalServerError({ error: error as Error, headers }); - sendErrorResponse({ - res, - reqId, - error: standardError, - }); + if (dataStreamer.connected) { + dataStreamer.streamResponses({ + ...standardError, + type: ERROR_TYPE, + }); + } else { + sendErrorResponse({ + res, + reqId, + error: standardError, + }); + } + } finally { + if (dataStreamer.connected) { + dataStreamer.disconnect(); + } } }; } @@ -385,13 +445,18 @@ const hasConversationUserIdChanged = ( return conversation.userId !== userId; }; +type MessagesParam = Parameters< + ConversationsService["addManyConversationMessages"] +>[0]["messages"]; + interface AddMessagesToConversationParams { conversations: ConversationsService; conversation: Conversation; store: boolean; metadata?: Record; input: CreateResponseRequest["body"]["input"]; - messages: Array; + messages: MessagesParam; + responseId: ObjectId; } const saveMessagesToConversation = async ({ @@ -401,13 +466,19 @@ const saveMessagesToConversation = async ({ metadata, input, messages, + responseId, }: AddMessagesToConversationParams) => { const messagesToAdd = [ ...convertInputToDBMessages(input, store, metadata), ...messages.map((message) => formatMessage(message, store, metadata)), ]; + // handle setting the response id for the last message + // this corresponds to the response id in the response stream + if (messagesToAdd.length > 0) { + messagesToAdd[messagesToAdd.length - 1].id = responseId; + } - await conversations.addManyConversationMessages({ + return await conversations.addManyConversationMessages({ conversationId: conversation._id, messages: messagesToAdd, }); @@ -417,7 +488,7 @@ const convertInputToDBMessages = ( input: CreateResponseRequest["body"]["input"], store: boolean, metadata?: Record -): Array => { +): MessagesParam => { if (typeof input === "string") { return [formatMessage({ role: "user", content: input }, store, metadata)]; } @@ -433,14 +504,52 @@ const convertInputToDBMessages = ( }; const formatMessage = ( - message: SomeMessage, + message: MessagesParam[number], store: boolean, metadata?: Record -): SomeMessage => { +): MessagesParam[number] => { + // store a placeholder string if we're not storing message data + const content = store ? message.content : ""; + // handle cleaning custom data if we're not storing message data + const customData = { + ...message.customData, + query: store ? message.customData?.query : "", + reason: store ? message.customData?.reason : "", + }; + return { ...message, - // store a placeholder string if we're not storing message data - content: store ? message.content : "", + content, metadata, + customData, + }; +}; + +interface BaseResponseData { + responseId: ObjectId; + data: CreateResponseRequest["body"]; +} + +const makeBaseResponseData = ({ responseId, data }: BaseResponseData) => { + return { + id: responseId.toString(), + object: "response" as const, + error: null, + incomplete_details: null, + instructions: data.instructions ?? null, + max_output_tokens: data.max_output_tokens ?? 
null, + model: data.model, + output_text: "", + output: [], + parallel_tool_calls: true, + previous_response_id: data.previous_response_id ?? null, + store: data.store, + temperature: data.temperature, + stream: data.stream, + tool_choice: data.tool_choice, + tools: data.tools ?? [], + top_p: null, + user: data.user, + metadata: data.metadata ?? null, }; }; diff --git a/packages/mongodb-chatbot-server/src/routes/responses/errors.ts b/packages/mongodb-chatbot-server/src/routes/responses/errors.ts index f5b6822e9..e4fd783c4 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/errors.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/errors.ts @@ -1,5 +1,5 @@ import { - APIError, + type APIError, BadRequestError, InternalServerError, NotFoundError, @@ -43,6 +43,13 @@ export enum ERROR_CODE { } // --- OPENAI ERROR WRAPPERS --- +export type SomeOpenAIAPIError = + | APIError + | BadRequestError + | NotFoundError + | RateLimitError + | InternalServerError; + interface MakeOpenAIErrorParams { error: Error; headers: Record; @@ -51,7 +58,7 @@ interface MakeOpenAIErrorParams { export const makeInternalServerError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? "Internal server error"; const _error = { ...error, @@ -65,7 +72,7 @@ export const makeInternalServerError = ({ export const makeBadRequestError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? "Bad request"; const _error = { ...error, @@ -79,7 +86,7 @@ export const makeBadRequestError = ({ export const makeNotFoundError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? "Not found"; const _error = { ...error, @@ -93,7 +100,7 @@ export const makeNotFoundError = ({ export const makeRateLimitError = ({ error, headers, -}: MakeOpenAIErrorParams): APIError => { +}: MakeOpenAIErrorParams) => { const message = error.message ?? 
"Rate limit exceeded"; const _error = { ...error, diff --git a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts index 9bfb9b29e..38c735e8b 100644 --- a/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/responses/responsesRouter.test.ts @@ -1,131 +1,189 @@ -import type { Express } from "express"; -import request from "supertest"; -import { AppConfig } from "../../app"; -import { DEFAULT_API_PREFIX } from "../../app"; -import { makeTestApp } from "../../test/testHelpers"; -import { makeTestAppConfig } from "../../test/testHelpers"; -import { basicResponsesRequestBody } from "../../test/testConfig"; -import { ERROR_TYPE, ERROR_CODE, makeBadRequestError } from "./errors"; -import { CreateResponseRequest } from "./createResponse"; +import type { Server } from "http"; +import { + makeTestLocalServer, + makeOpenAiClient, + makeCreateResponseRequestStream, + type Stream, +} from "../../test/testHelpers"; +import { makeDefaultConfig } from "../../test/testConfig"; +import { + ERROR_CODE, + ERROR_TYPE, + makeBadRequestError, + type SomeOpenAIAPIError, +} from "./errors"; jest.setTimeout(60000); describe("Responses Router", () => { - const ipAddress = "127.0.0.1"; - const responsesEndpoint = DEFAULT_API_PREFIX + "/responses"; - let appConfig: AppConfig; - - beforeAll(async () => { - ({ appConfig } = await makeTestAppConfig()); + let server: Server; + let ipAddress: string; + let origin: string; + + afterEach(async () => { + if (server?.listening) { + await new Promise((resolve) => { + server.close(() => resolve()); + }); + } + jest.clearAllMocks(); }); - const makeCreateResponseRequest = ( - app: Express, - origin: string, - body?: Partial - ) => { - return request(app) - .post(responsesEndpoint) - .set("X-Forwarded-For", ipAddress) - .set("Origin", origin) - .send({ ...basicResponsesRequestBody, ...body }); - }; - - it("should return 200 given a valid request", async () => { - const { app, origin } = await makeTestApp(appConfig); + it("should return responses given a valid request", async () => { + ({ server, ipAddress, origin } = await makeTestLocalServer()); - const res = await makeCreateResponseRequest(app, origin); + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); - expect(res.status).toBe(200); + await expectValidResponses({ stream }); }); - it("should return 500 when handling an unknown error", async () => { + it("should return an OpenAI error when handling an unknown error", async () => { const errorMessage = "Unknown error"; - const { app, origin } = await makeTestApp({ - ...appConfig, - responsesRouterConfig: { - ...appConfig.responsesRouterConfig, - createResponse: { - ...appConfig.responsesRouterConfig.createResponse, - generateResponse: () => Promise.reject(new Error(errorMessage)), - }, - }, - }); - const res = await makeCreateResponseRequest(app, origin); + const appConfig = await makeDefaultConfig(); + appConfig.responsesRouterConfig.createResponse.generateResponse = () => { + throw new Error(errorMessage); + }; - expect(res.status).toBe(500); - expect(res.body.type).toBe(ERROR_TYPE); - expect(res.body.code).toBe(ERROR_CODE.SERVER_ERROR); - expect(res.body.error).toEqual({ - type: ERROR_TYPE, - code: ERROR_CODE.SERVER_ERROR, - message: errorMessage, - }); - }); + ({ server, ipAddress, origin } = await 
makeTestLocalServer(appConfig)); - it("should return the openai error when service throws an openai error", async () => { - const errorMessage = "Bad request input"; - const { app, origin } = await makeTestApp({ - ...appConfig, - responsesRouterConfig: { - ...appConfig.responsesRouterConfig, - createResponse: { - ...appConfig.responsesRouterConfig.createResponse, - generateResponse: () => - Promise.reject( - makeBadRequestError({ - error: new Error(errorMessage), - headers: {}, - }) - ), - }, + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); + + await expectInvalidResponses({ + stream, + error: { + type: ERROR_TYPE, + code: ERROR_CODE.SERVER_ERROR, + message: errorMessage, }, }); + }); - const res = await makeCreateResponseRequest(app, origin); + it("should return the OpenAI error when service throws an OpenAI error", async () => { + const errorMessage = "Bad request input"; - expect(res.status).toBe(400); - expect(res.body.type).toBe(ERROR_TYPE); - expect(res.body.code).toBe(ERROR_CODE.INVALID_REQUEST_ERROR); - expect(res.body.error).toEqual({ - type: ERROR_TYPE, - code: ERROR_CODE.INVALID_REQUEST_ERROR, - message: errorMessage, + const appConfig = await makeDefaultConfig(); + appConfig.responsesRouterConfig.createResponse.generateResponse = () => + Promise.reject( + makeBadRequestError({ + error: new Error(errorMessage), + headers: {}, + }) + ); + + ({ server, ipAddress, origin } = await makeTestLocalServer(appConfig)); + + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); + + await expectInvalidResponses({ + stream, + error: { + type: ERROR_TYPE, + code: ERROR_CODE.INVALID_REQUEST_ERROR, + message: errorMessage, + }, }); }); - test("Should apply responses router rate limit and return an openai error", async () => { + it("Should return an OpenAI error when rate limit is hit", async () => { const rateLimitErrorMessage = "Error: rate limit exceeded!"; - const { app, origin } = await makeTestApp({ - responsesRouterConfig: { - rateLimitConfig: { - routerRateLimitConfig: { - windowMs: 50000, // Big window to cover test duration - max: 1, // Only one request should be allowed - message: rateLimitErrorMessage, - }, - }, + const appConfig = await makeDefaultConfig(); + appConfig.responsesRouterConfig.rateLimitConfig = { + routerRateLimitConfig: { + windowMs: 500000, // Big window to cover test duration + max: 1, // Only one request should be allowed + message: rateLimitErrorMessage, }, - }); + }; - const successRes = await makeCreateResponseRequest(app, origin); - const rateLimitedRes = await makeCreateResponseRequest(app, origin); + ({ server, ipAddress, origin } = await makeTestLocalServer(appConfig)); - expect(successRes.status).toBe(200); - expect(successRes.error).toBeFalsy(); + const openAiClient = makeOpenAiClient(origin, ipAddress); + const stream = await makeCreateResponseRequestStream(openAiClient); - expect(rateLimitedRes.status).toBe(429); - expect(rateLimitedRes.error).toBeTruthy(); - expect(rateLimitedRes.body.type).toBe(ERROR_TYPE); - expect(rateLimitedRes.body.code).toBe(ERROR_CODE.RATE_LIMIT_ERROR); - expect(rateLimitedRes.body.error).toEqual({ - type: ERROR_TYPE, - code: ERROR_CODE.RATE_LIMIT_ERROR, - message: rateLimitErrorMessage, - }); - expect(rateLimitedRes.body.headers["x-forwarded-for"]).toBe(ipAddress); - expect(rateLimitedRes.body.headers["origin"]).toBe(origin); + try { + await 
makeCreateResponseRequestStream(openAiClient); + + fail("expected rate limit error"); + } catch (error) { + expect((error as SomeOpenAIAPIError).status).toBe(429); + expect((error as SomeOpenAIAPIError).error).toEqual({ + type: ERROR_TYPE, + code: ERROR_CODE.RATE_LIMIT_ERROR, + message: rateLimitErrorMessage, + }); + } + + await expectValidResponses({ stream }); }); }); + +// --- HELPERS --- + +interface ExpectValidResponsesParams { + stream: Stream; +} + +const expectValidResponses = async ({ stream }: ExpectValidResponsesParams) => { + const responses: any[] = []; + for await (const event of stream) { + responses.push(event); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(3); + + expect(responses[0].type).toBe("response.created"); + expect(responses[1].type).toBe("response.in_progress"); + expect(responses[2].type).toBe("response.completed"); + + responses.forEach(({ sequence_number, response }, index) => { + expect(sequence_number).toBe(index); + expect(typeof response.id).toBe("string"); + expect(response.object).toBe("response"); + expect(response.error).toBeNull(); + expect(response.model).toBe("mongodb-chat-latest"); + }); +}; + +interface ExpectInvalidResponsesParams { + stream: Stream; + error: { + type: string; + code: string; + message: string; + }; +} + +const expectInvalidResponses = async ({ + stream, + error, +}: ExpectInvalidResponsesParams) => { + const responses: any[] = []; + try { + for await (const event of stream) { + responses.push(event); + } + + fail("expected error"); + } catch (err: any) { + expect(err.type).toBe(error.type); + expect(err.code).toBe(error.code); + expect(err.error.type).toBe(error.type); + expect(err.error.code).toBe(error.code); + expect(err.error.message).toBe(error.message); + } + + expect(Array.isArray(responses)).toBe(true); + expect(responses.length).toBe(2); + + expect(responses[0].type).toBe("response.created"); + expect(responses[1].type).toBe("response.in_progress"); + + expect(responses[0].sequence_number).toBe(0); + expect(responses[1].sequence_number).toBe(1); +}; diff --git a/packages/mongodb-chatbot-server/src/test/testConfig.ts b/packages/mongodb-chatbot-server/src/test/testConfig.ts index 42064bcb8..b826cf63f 100644 --- a/packages/mongodb-chatbot-server/src/test/testConfig.ts +++ b/packages/mongodb-chatbot-server/src/test/testConfig.ts @@ -175,7 +175,6 @@ export const MONGO_CHAT_MODEL = "mongodb-chat-latest"; export const basicResponsesRequestBody = { model: MONGO_CHAT_MODEL, - stream: true, input: "What is MongoDB?", }; diff --git a/packages/mongodb-chatbot-server/src/test/testHelpers.ts b/packages/mongodb-chatbot-server/src/test/testHelpers.ts index 5f68b1aff..156a9af43 100644 --- a/packages/mongodb-chatbot-server/src/test/testHelpers.ts +++ b/packages/mongodb-chatbot-server/src/test/testHelpers.ts @@ -1,6 +1,13 @@ import { strict as assert } from "assert"; -import { AppConfig, makeApp } from "../app"; -import { makeDefaultConfig, memoryDb, systemPrompt } from "./testConfig"; +import { OpenAI } from "mongodb-rag-core/openai"; +import { AppConfig, DEFAULT_API_PREFIX, makeApp } from "../app"; +import { + makeDefaultConfig, + memoryDb, + systemPrompt, + basicResponsesRequestBody, +} from "./testConfig"; +import type { CreateResponseRequest } from "../routes/responses/createResponse"; export async function makeTestAppConfig( defaultConfigOverrides?: PartialAppConfig @@ -33,9 +40,11 @@ export type PartialAppConfig = Omit< > & { conversationsRouterConfig?: Partial; responsesRouterConfig?: 
Partial; + port?: number; }; -export const TEST_ORIGIN = "http://localhost:5173"; +export const TEST_PORT = 5173; +export const TEST_ORIGIN = `http://localhost:`; /** Helper function to quickly make an app for testing purposes. Can't be called @@ -45,7 +54,7 @@ export const TEST_ORIGIN = "http://localhost:5173"; export async function makeTestApp(defaultConfigOverrides?: PartialAppConfig) { // ip address for local host const ipAddress = "127.0.0.1"; - const origin = TEST_ORIGIN; + const origin = TEST_ORIGIN + (defaultConfigOverrides?.port ?? TEST_PORT); const { appConfig, systemPrompt, mongodb } = await makeTestAppConfig( defaultConfigOverrides @@ -63,6 +72,53 @@ export async function makeTestApp(defaultConfigOverrides?: PartialAppConfig) { }; } +export const TEST_OPENAI_API_KEY = "test-api-key"; + +/** + Helper function to quickly make a local server for testing purposes. + Builds on the other helpers for app/config stuff. + @param defaultConfigOverrides - optional overrides for default app config + */ +export const makeTestLocalServer = async ( + defaultConfigOverrides?: PartialAppConfig, + port?: number +) => { + const testAppResult = await makeTestApp({ + ...defaultConfigOverrides, + port, + }); + + const server = testAppResult.app.listen(port ?? TEST_PORT); + + return { ...testAppResult, server }; +}; + +export const makeOpenAiClient = (origin: string, ipAddress: string) => { + return new OpenAI({ + baseURL: origin + DEFAULT_API_PREFIX, + apiKey: TEST_OPENAI_API_KEY, + defaultHeaders: { + Origin: origin, + "X-Forwarded-For": ipAddress, + }, + }); +}; + +export type Stream = Awaited< + ReturnType +>; + +export const makeCreateResponseRequestStream = ( + openAiClient: OpenAI, + body?: Omit, "stream"> +) => { + return openAiClient.responses.create({ + ...basicResponsesRequestBody, + ...body, + stream: true, + }); +}; + /** Create a URL to represent a client-side route on the test origin. @param path - path to append to the origin base URL. 
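Reviewer note (illustrative, not part of the patch): the helpers above are what the tests use to reach the new streaming route through the OpenAI SDK. The sketch below pulls those pieces together into a standalone client call so the intended end-to-end usage is easier to see. The port, API key, headers, model, and input are taken from the test config (TEST_PORT, TEST_OPENAI_API_KEY, MONGO_CHAT_MODEL, basicResponsesRequestBody); the "/api/v1" prefix is an assumption standing in for DEFAULT_API_PREFIX.

// Sketch of a client driving the streaming /responses route, mirroring
// makeOpenAiClient + makeCreateResponseRequestStream from testHelpers.ts.
// Concrete values below are assumptions taken from the test config.
import { OpenAI } from "openai";

async function main() {
  const openAiClient = new OpenAI({
    // The router is mounted under the app's API prefix; "/api/v1" is assumed here.
    baseURL: "http://localhost:5173/api/v1",
    apiKey: "test-api-key", // placeholder value, as in TEST_OPENAI_API_KEY
    defaultHeaders: {
      Origin: "http://localhost:5173",
      "X-Forwarded-For": "127.0.0.1",
    },
  });

  // The route only supports streaming, so stream must always be true.
  const stream = await openAiClient.responses.create({
    model: "mongodb-chat-latest",
    input: "What is MongoDB?",
    stream: true,
  });

  // The current skeleton emits response.created, response.in_progress, and
  // response.completed, each stamped with an incrementing sequence_number.
  for await (const event of stream) {
    console.log(event.type);
  }
}

main().catch(console.error);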
diff --git a/packages/mongodb-rag-core/package.json b/packages/mongodb-rag-core/package.json index 720bfbcba..b6892a1ff 100644 --- a/packages/mongodb-rag-core/package.json +++ b/packages/mongodb-rag-core/package.json @@ -101,7 +101,7 @@ "ignore": "^5.3.2", "langchain": "^0.3.5", "mongodb": "^6.3.0", - "openai": "^4.95.0", + "openai": "^5.9.1", "rimraf": "^6.0.1", "simple-git": "^3.27.0", "toml": "^3.0.0", diff --git a/packages/mongodb-rag-core/src/DataStreamer.test.ts b/packages/mongodb-rag-core/src/DataStreamer.test.ts index b38b97a3d..a661cdbd2 100644 --- a/packages/mongodb-rag-core/src/DataStreamer.test.ts +++ b/packages/mongodb-rag-core/src/DataStreamer.test.ts @@ -1,16 +1,23 @@ -import { DataStreamer, makeDataStreamer } from "./DataStreamer"; +import { + DataStreamer, + makeDataStreamer, + type ResponsesStreamParams, +} from "./DataStreamer"; import { OpenAI } from "openai"; import { createResponse } from "node-mocks-http"; import { EventEmitter } from "events"; import { Response } from "express"; -let res: ReturnType & Response; -const dataStreamer = makeDataStreamer(); describe("Data Streaming", () => { + let dataStreamer: DataStreamer; + let res: ReturnType & Response; + + beforeAll(() => { + dataStreamer = makeDataStreamer(); + }); + beforeEach(() => { - res = createResponse({ - eventEmitter: EventEmitter, - }); + res = createResponse({ eventEmitter: EventEmitter }); dataStreamer.connect(res); }); @@ -79,6 +86,30 @@ describe("Data Streaming", () => { `data: {"type":"delta","data":"Once upon"}\n\ndata: {"type":"delta","data":" a time there was a"}\n\ndata: {"type":"delta","data":" very long string."}\n\n` ); }); + + it("Streams Responses API events as valid SSE events to the client", () => { + dataStreamer.streamResponses({ + type: "response.created", + id: "test1", + } as ResponsesStreamParams); + dataStreamer.streamResponses({ + type: "response.in_progress", + id: "test2", + } as ResponsesStreamParams); + dataStreamer.streamResponses({ + type: "response.output_text.delta", + id: "test3", + } as ResponsesStreamParams); + dataStreamer.streamResponses({ + type: "response.completed", + id: "test4", + } as ResponsesStreamParams); + + const data = res._getData(); + expect(data).toBe( + `event: response.created\ndata: {"type":"response.created","id":"test1","sequence_number":0}\n\nevent: response.in_progress\ndata: {"type":"response.in_progress","id":"test2","sequence_number":1}\n\nevent: response.output_text.delta\ndata: {"type":"response.output_text.delta","id":"test3","sequence_number":2}\n\nevent: response.completed\ndata: {"type":"response.completed","id":"test4","sequence_number":3}\n\n` + ); + }); }); function createChatCompletionWithDelta( diff --git a/packages/mongodb-rag-core/src/DataStreamer.ts b/packages/mongodb-rag-core/src/DataStreamer.ts index 423e6ec21..12d56b2bf 100644 --- a/packages/mongodb-rag-core/src/DataStreamer.ts +++ b/packages/mongodb-rag-core/src/DataStreamer.ts @@ -16,6 +16,7 @@ interface ServerSentEventDispatcher { disconnect(): void; sendData(data: Data): void; sendEvent(eventType: string, data: Data): void; + sendResponsesEvent(data: OpenAI.Responses.ResponseStreamEvent): void; } type ServerSentEventData = object | string; @@ -43,6 +44,10 @@ function makeServerSentEventDispatcher< res.write(`event: ${eventType}\n`); res.write(`data: ${JSON.stringify(data)}\n\n`); }, + sendResponsesEvent(data) { + res.write(`event: ${data.type}\n`); + res.write(`data: ${JSON.stringify(data)}\n\n`); + }, }; } @@ -53,6 +58,10 @@ interface StreamParams { type StreamEvent = { 
type: string; data: unknown }; +export type ResponsesStreamParams = + | Omit + | Omit; + /** Event when server streams additional message response to the client. */ @@ -122,6 +131,7 @@ export interface DataStreamer { disconnect(): void; streamData(data: SomeStreamEvent): void; stream(params: StreamParams): Promise; + streamResponses(data: ResponsesStreamParams): void; } /** @@ -130,6 +140,7 @@ export interface DataStreamer { export function makeDataStreamer(): DataStreamer { let connected = false; let sse: ServerSentEventDispatcher | undefined; + let responseSequenceNumber = 0; return { get connected() { @@ -161,7 +172,7 @@ export function makeDataStreamer(): DataStreamer { /** Streams single item of data in an event stream. */ - streamData(data: SomeStreamEvent) { + streamData(data) { if (!this.connected) { throw new Error( `Tried to stream data, but there's no SSE connection. Call DataStreamer.connect() first.` @@ -173,7 +184,7 @@ export function makeDataStreamer(): DataStreamer { /** Streams all message events in an event stream. */ - async stream({ stream }: StreamParams) { + async stream({ stream }) { if (!this.connected) { throw new Error( `Tried to stream data, but there's no SSE connection. Call DataStreamer.connect() first.` @@ -197,5 +208,19 @@ export function makeDataStreamer(): DataStreamer { } return streamedData; }, + + async streamResponses(data) { + if (!this.connected) { + throw new Error( + `Tried to stream data, but there's no SSE connection. Call DataStreamer.connect() first.` + ); + } + sse?.sendResponsesEvent({ + ...data, + sequence_number: responseSequenceNumber, + } as OpenAI.Responses.ResponseStreamEvent); + + responseSequenceNumber++; + }, }; } From 8854ff7ee4509052ce14e51e7fc968d1f245f590 Mon Sep 17 00:00:00 2001 From: Ben Perlmutter Date: Mon, 21 Jul 2025 14:19:12 -0400 Subject: [PATCH 6/6] customization logic --- package-lock.json | 187 +++++++++++++++++- .../package.json | 2 + .../generateResponseWithSearchTool.ts | 21 +- .../processors/makeResponesSystemPrompt.ts | 24 +++ .../src/processors/mongoDbInputGuardrail.ts | 38 +++- .../src/processors/GenerateResponse.ts | 3 + .../src/routes/responses/createResponse.ts | 1 + 7 files changed, 264 insertions(+), 12 deletions(-) create mode 100644 packages/chatbot-server-mongodb-public/src/processors/makeResponesSystemPrompt.ts diff --git a/package-lock.json b/package-lock.json index 8efe8aaf8..fc113398f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14593,6 +14593,12 @@ "node": ">=18.0.0" } }, + "node_modules/@standard-schema/spec": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.0.0.tgz", + "integrity": "sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==", + "license": "MIT" + }, "node_modules/@stdlib/array-base-accessor-getter": { "version": "0.2.1", "license": "Apache-2.0", @@ -53199,9 +53205,10 @@ } }, "node_modules/zod": { - "version": "3.25.67", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.67.tgz", - "integrity": "sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==", + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" } @@ -53384,9 +53391,11 @@ "version": "0.21.2", "license": "Apache-2.0", "dependencies": { + 
"@ai-sdk/openai": "^2.0.0-beta.11", "@segment/analytics-node": "^2.2.1", "@slack/web-api": "^7.8.0", "ahocorasick": "^1.0.2", + "ai": "^5.0.0-beta.25", "common-tags": "^1.8.2", "cookie-parser": "^1.4.6", "dotenv": "^16.0.3", @@ -53433,6 +53442,52 @@ "npm": ">=8" } }, + "packages/chatbot-server-mongodb-public/node_modules/@ai-sdk/openai": { + "version": "2.0.0-beta.11", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-2.0.0-beta.11.tgz", + "integrity": "sha512-HQXUMb1V6Xr8EBYvEDwNb8ISyRqyxg2zUst7lzPb6s1nGDKJRBTfSyytNWRL9dZ9vxjM2wK34cltCfZbjaHpAA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0-beta.1", + "@ai-sdk/provider-utils": "3.0.0-beta.5" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + } + }, + "packages/chatbot-server-mongodb-public/node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider-utils": { + "version": "3.0.0-beta.5", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.0-beta.5.tgz", + "integrity": "sha512-4Dv/wiGZrvO6fI7P0yMLa4XZru0XW8LPibTObbkHBdweLUVGIze7aCfxxQeY44Uqcbl/h6/yBTkx2XmPtwf/Ow==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0-beta.1", + "@standard-schema/spec": "^1.0.0", + "eventsource-parser": "^3.0.3", + "zod-to-json-schema": "^3.24.1" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + } + }, + "packages/chatbot-server-mongodb-public/node_modules/@ai-sdk/provider": { + "version": "2.0.0-beta.1", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0-beta.1.tgz", + "integrity": "sha512-Z8SPncMtS3RsoXITmT7NVwrAq6M44dmw0DoUOYJqNNtCu8iMWuxB8Nxsoqpa0uEEy9R1V1ZThJAXTYgjTUxl3w==", + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, "packages/chatbot-server-mongodb-public/node_modules/@types/express-serve-static-core": { "version": "5.0.6", "dev": true, @@ -53444,6 +53499,70 @@ "@types/send": "*" } }, + "packages/chatbot-server-mongodb-public/node_modules/ai": { + "version": "5.0.0-beta.25", + "resolved": "https://registry.npmjs.org/ai/-/ai-5.0.0-beta.25.tgz", + "integrity": "sha512-pbfFqtQvz7hiDw6TwUH75CK9FgrZFBsxqbW4yW0aqluHw3nRbhf0w1u2AMiYgvWMy8Xf8TkBbMtY4vyMc4neeA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/gateway": "1.0.0-beta.11", + "@ai-sdk/provider": "2.0.0-beta.1", + "@ai-sdk/provider-utils": "3.0.0-beta.5", + "@opentelemetry/api": "1.9.0" + }, + "bin": { + "ai": "dist/bin/ai.min.js" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + } + }, + "packages/chatbot-server-mongodb-public/node_modules/ai/node_modules/@ai-sdk/gateway": { + "version": "1.0.0-beta.11", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-1.0.0-beta.11.tgz", + "integrity": "sha512-dnRUPzSLvp3xvIx6M4FIz4ht8dfL8JkPKwH+akj10im4zbxUii3c3TQ3BJLRdx2Gq/SeljE9H0dX7PDtVyIrbQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0-beta.1", + "@ai-sdk/provider-utils": "3.0.0-beta.5" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + } + }, + "packages/chatbot-server-mongodb-public/node_modules/ai/node_modules/@ai-sdk/provider-utils": { + "version": "3.0.0-beta.5", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.0-beta.5.tgz", + "integrity": 
"sha512-4Dv/wiGZrvO6fI7P0yMLa4XZru0XW8LPibTObbkHBdweLUVGIze7aCfxxQeY44Uqcbl/h6/yBTkx2XmPtwf/Ow==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0-beta.1", + "@standard-schema/spec": "^1.0.0", + "eventsource-parser": "^3.0.3", + "zod-to-json-schema": "^3.24.1" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4" + } + }, + "packages/chatbot-server-mongodb-public/node_modules/eventsource-parser": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.3.tgz", + "integrity": "sha512-nVpZkTMM9rF6AQ9gPJpFsNAMt48wIzB5TQgiTLdHiuO8XEDhUgZEhqKlZWXbIzo9VmJ/HvysHqEaVeD5v9TPvA==", + "license": "MIT", + "engines": { + "node": ">=20.0.0" + } + }, "packages/datasets": { "version": "1.1.2", "license": "ISC", @@ -57157,7 +57276,7 @@ "ignore": "^5.3.2", "langchain": "^0.3.5", "mongodb": "^6.3.0", - "openai": "^4.95.0", + "openai": "^5.9.1", "rimraf": "^6.0.1", "simple-git": "^3.27.0", "toml": "^3.0.0", @@ -57775,6 +57894,45 @@ "@langchain/core": ">=0.2.26 <0.4.0" } }, + "packages/mongodb-rag-core/node_modules/@langchain/openai/node_modules/@types/node": { + "version": "18.19.120", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.120.tgz", + "integrity": "sha512-WtCGHFXnVI8WHLxDAt5TbnCM4eSE+nI0QN2NJtwzcgMhht2eNz6V9evJrk+lwC8bCY8OWV5Ym8Jz7ZEyGnKnMA==", + "license": "MIT", + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "packages/mongodb-rag-core/node_modules/@langchain/openai/node_modules/openai": { + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", + "license": "Apache-2.0", + "dependencies": { + "@types/node": "^18.11.18", + "@types/node-fetch": "^2.6.4", + "abort-controller": "^3.0.0", + "agentkeepalive": "^4.2.1", + "form-data-encoder": "1.7.2", + "formdata-node": "^4.3.2", + "node-fetch": "^2.6.7" + }, + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.23.8" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } + } + }, "packages/mongodb-rag-core/node_modules/@types/jest": { "version": "26.0.24", "dev": true, @@ -58011,6 +58169,27 @@ "node": ">=16 || 14 >=14.17" } }, + "packages/mongodb-rag-core/node_modules/openai": { + "version": "5.10.1", + "resolved": "https://registry.npmjs.org/openai/-/openai-5.10.1.tgz", + "integrity": "sha512-fq6xVfv1/gpLbsj8fArEt3b6B9jBxdhAK+VJ+bDvbUvNd+KTLlA3bnDeYZaBsGH9LUhJ1M1yXfp9sEyBLMx6eA==", + "license": "Apache-2.0", + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.23.8" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } + } + }, "packages/mongodb-rag-core/node_modules/path-scurry": { "version": "2.0.0", "license": "BlueOak-1.0.0", diff --git a/packages/chatbot-server-mongodb-public/package.json b/packages/chatbot-server-mongodb-public/package.json index 8fc9de8ec..9acc3bc77 100644 --- a/packages/chatbot-server-mongodb-public/package.json +++ b/packages/chatbot-server-mongodb-public/package.json @@ -27,9 +27,11 @@ "generate-eval-cases": "ts-node src/eval/bin/generateEvalCasesYamlFromCSV.ts" }, "dependencies": { + "@ai-sdk/openai": "^2.0.0-beta.11", "@segment/analytics-node": "^2.2.1", "@slack/web-api": "^7.8.0", "ahocorasick": "^1.0.2", + "ai": "^5.0.0-beta.25", "common-tags": "^1.8.2", 
"cookie-parser": "^1.4.6", "dotenv": "^16.0.3", diff --git a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts index 692ba6e37..785f9977d 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts @@ -9,8 +9,6 @@ import { import { CoreAssistantMessage, CoreMessage, - LanguageModel, - streamText, ToolCallPart, ToolChoice, ToolSet, @@ -34,13 +32,18 @@ import { SEARCH_TOOL_NAME, SearchTool, } from "../tools/search"; +// Using v5-beta version of ai-sdk for new functionality. +// Refer to annoucement for more info https://v5.ai-sdk.dev/docs/announcing-ai-sdk-5-beta#announcing-ai-sdk-5-beta +// Specifically, the new stopWhen option is useful +import { streamText, LanguageModel, hasToolCall } from "ai"; +export type MakeSystemPrompt = (customSystemPrompt?: string) => SystemMessage; export interface GenerateResponseWithSearchToolParams { languageModel: LanguageModel; llmNotWorkingMessage: string; llmRefusalMessage: string; inputGuardrail?: InputGuardrail; - systemMessage: SystemMessage; + systemMessage: MakeSystemPrompt; filterPreviousMessages?: FilterPreviousMessages; /** Required tool for performing content search and gathering {@link References} @@ -131,6 +134,8 @@ export function makeGenerateResponseWithSearchTool({ reqId, dataStreamer, request, + customSystemPrompt, + tools, }) { const streamingModeActive = shouldStream === true && @@ -152,12 +157,13 @@ export function makeGenerateResponseWithSearchTool({ const toolSet = { [SEARCH_TOOL_NAME]: searchTool, ...(additionalTools ?? {}), + // TODO: get the client-defined tools into here } satisfies ToolSet; const generationArgs = { model: languageModel, messages: [ - systemMessage, + systemMessage(customSystemPrompt), ...filteredPreviousMessages, userMessage, ] satisfies CoreMessage[], @@ -176,6 +182,8 @@ export function makeGenerateResponseWithSearchTool({ reqId, dataStreamer, request, + tools, + customSystemPrompt, }) : undefined; @@ -201,6 +209,11 @@ export function makeGenerateResponseWithSearchTool({ const result = streamText({ ...generationArgs, abortSignal: generationController.signal, + // Something like this. refer to https://v5.ai-sdk.dev/docs/announcing-ai-sdk-5-beta#announcing-ai-sdk-5-beta + // Want to stop the generation after the client-defined tool is called + // But continue after the search tool + stopWhen: tools.map((tool) => hasToolCall(tool.name)), + onStepFinish: async ({ toolResults, toolCalls }) => { toolCalls?.forEach((toolCall) => { if (toolCall.toolName === SEARCH_TOOL_NAME) { diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeResponesSystemPrompt.ts b/packages/chatbot-server-mongodb-public/src/processors/makeResponesSystemPrompt.ts new file mode 100644 index 000000000..c6efd59cc --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/processors/makeResponesSystemPrompt.ts @@ -0,0 +1,24 @@ +import { systemPrompt } from "../systemPrompt"; +import { MakeSystemPrompt } from "./generateResponseWithSearchTool"; + +// TODO: will need to evalute this new prompt works as expected +export const makeResponsesSystemPrompt: MakeSystemPrompt = ( + customSystemPrompt +) => { + if (!customSystemPrompt) { + return systemPrompt; + } else { + return { + role: "system", + content: ` +Always adhere to the . This is your core behavior. 
+The developer has also provided a <custom_system_prompt>. Follow these instructions as well.
+<system_prompt>
+${systemPrompt.content}
+</system_prompt>
+<custom_system_prompt>
+${customSystemPrompt}
+</custom_system_prompt>`,
+    };
+  }
+};
diff --git a/packages/chatbot-server-mongodb-public/src/processors/mongoDbInputGuardrail.ts b/packages/chatbot-server-mongodb-public/src/processors/mongoDbInputGuardrail.ts
index c9e565835..996011c32 100644
--- a/packages/chatbot-server-mongodb-public/src/processors/mongoDbInputGuardrail.ts
+++ b/packages/chatbot-server-mongodb-public/src/processors/mongoDbInputGuardrail.ts
@@ -206,12 +206,20 @@ ${JSON.stringify(examplePair.output, null, 2)}
 export interface MakeUserMessageMongoDbGuardrailParams {
   model: LanguageModelV1;
 }
+// TODO: will need to evaluate whether this new flow works as expected
 export const makeMongoDbInputGuardrail = ({
   model,
 }: MakeUserMessageMongoDbGuardrailParams) => {
   const userMessageMongoDbGuardrail: InputGuardrail = async ({
     latestMessageText,
+    customSystemPrompt,
+    tools,
   }) => {
+    const userMessage = makeInputGuardrailUserMessage({
+      latestMessageText,
+      customSystemPrompt,
+      tools,
+    });
     const {
       object: { type, reasoning },
     } = await generateObject({
@@ -219,10 +227,7 @@
       schema: UserMessageMongoDbGuardrailFunctionSchema,
       schemaDescription: inputGuardrailMetadata.description,
       schemaName: inputGuardrailMetadata.name,
-      messages: [
-        { role: "system", content: systemPrompt },
-        { role: "user" as const, content: latestMessageText },
-      ],
+      messages: [{ role: "system", content: systemPrompt }, userMessage],
       mode: "json",
     });
     const rejected = type === "irrelevant" || type === "inappropriate";
@@ -234,3 +239,28 @@
   };
   return userMessageMongoDbGuardrail;
 };
+
+function makeInputGuardrailUserMessage({
+  latestMessageText,
+  customSystemPrompt,
+  tools,
+}: Pick<
+  GenerateResponseParams,
+  "latestMessageText" | "customSystemPrompt" | "tools"
+>) {
+  if (!customSystemPrompt && !tools) {
+    return {
+      role: "user" as const,
+      content: latestMessageText,
+    };
+  } else {
+    return {
+      role: "user" as const,
+      content: `${latestMessageText}${
+        customSystemPrompt
+          ? `<custom_system_prompt>${customSystemPrompt}</custom_system_prompt>`
+          : ""
+      }${tools ? 
`<tools>${JSON.stringify(tools)}</tools>` : ""}`,
+    };
+  }
+}
diff --git a/packages/mongodb-chatbot-server/src/processors/GenerateResponse.ts b/packages/mongodb-chatbot-server/src/processors/GenerateResponse.ts
index 8036319f1..b4f0ed59f 100644
--- a/packages/mongodb-chatbot-server/src/processors/GenerateResponse.ts
+++ b/packages/mongodb-chatbot-server/src/processors/GenerateResponse.ts
@@ -7,6 +7,7 @@ import {
   UserMessage,
 } from "mongodb-rag-core";
 import { Request as ExpressRequest } from "express";
+import { OpenAI } from "mongodb-rag-core/openai";
 
 export type ClientContext = Record<string, unknown>;
 
@@ -19,6 +20,8 @@ export interface GenerateResponseParams {
   reqId: string;
   conversation: Conversation;
   request?: ExpressRequest;
+  customSystemPrompt?: string;
+  toolDefinitions?: OpenAI.FunctionDefinition[];
 }
 
 export interface GenerateResponseReturnValue {
diff --git a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts
index 7d2654c43..500e8fa07 100644
--- a/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts
+++ b/packages/mongodb-chatbot-server/src/routes/responses/createResponse.ts
@@ -308,6 +308,7 @@ export function makeCreateResponseRoute({
         dataStreamer.streamResponses(inProgressMessage);
 
       // TODO: actually implement this call
+      // Also pass the toolDefinitions and customSystemPrompt (see the sketch below)
      const { messages } = await generateResponse({} as any);
 
      // --- STORE MESSAGES IN CONVERSATION ---
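+      // A rough sketch (not part of this patch's behavior) of how the call above
+      // might eventually be wired up once the TODO is implemented: forward the
+      // request's custom instructions and tool definitions through the new
+      // GenerateResponseParams fields. The exact mapping from the request body
+      // to these values is an assumption, not a settled API.
+      //
+      // const { messages } = await generateResponse({
+      //   conversation,
+      //   latestMessageText,
+      //   customSystemPrompt,
+      //   toolDefinitions,
+      //   shouldStream,
+      //   reqId,
+      //   dataStreamer,
+      // });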