Skip to content

Commit 5e3d859

Browse files
authored
feat: add support for automatic embeddings for the insert many tool MCP-236 (#688)
1 parent d25a9ed commit 5e3d859

File tree

11 files changed

+1037
-65
lines changed

11 files changed

+1037
-65
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ tests/tmp
1313
coverage
1414
# Generated assets by accuracy runs
1515
.accuracy
16+
17+
.DS_Store

CONTRIBUTING.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ npm test -- path/to/test/file.test.ts
7676
npm test -- path/to/directory
7777
```
7878

79+
#### Accuracy Tests and colima
80+
81+
If you use [colima](https://github.com/abiosoft/colima) to run Docker on Mac, you will need to apply [additional configuration](https://node.testcontainers.org/supported-container-runtimes/#colima) to ensure the accuracy tests run correctly.
82+
7983
## Troubleshooting
8084

8185
### Restart Server

src/common/errors.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export enum ErrorCodes {
77
NoEmbeddingsProviderConfigured = 1_000_005,
88
AtlasVectorSearchIndexNotFound = 1_000_006,
99
AtlasVectorSearchInvalidQuery = 1_000_007,
10+
Unexpected = 1_000_008,
1011
}
1112

1213
export class MongoDBError<ErrorCode extends ErrorCodes = ErrorCodes> extends Error {

src/common/search/embeddingsProvider.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { createFetch } from "@mongodb-js/devtools-proxy-support";
77
import { z } from "zod";
88

99
type EmbeddingsInput = string;
10-
type Embeddings = number[];
10+
type Embeddings = number[] | unknown[];
1111
export type EmbeddingParameters = {
1212
inputType: "query" | "document";
1313
};

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import z from "zod";
66
import { ErrorCodes, MongoDBError } from "../errors.js";
77
import { getEmbeddingsProvider } from "./embeddingsProvider.js";
88
import type { EmbeddingParameters, SupportedEmbeddingParameters } from "./embeddingsProvider.js";
9+
import { formatUntrustedData } from "../../tools/tool.js";
910

1011
export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]);
1112
export type Similarity = z.infer<typeof similarityEnum>;
@@ -103,7 +104,34 @@ export class VectorSearchEmbeddingsManager {
103104
return definition;
104105
}
105106

106-
async findFieldsWithWrongEmbeddings(
107+
async assertFieldsHaveCorrectEmbeddings(
108+
{ database, collection }: { database: string; collection: string },
109+
documents: Document[]
110+
): Promise<void> {
111+
const embeddingValidationResults = (
112+
await Promise.all(
113+
documents.map((document) => this.findFieldsWithWrongEmbeddings({ database, collection }, document))
114+
)
115+
).flat();
116+
117+
if (embeddingValidationResults.length > 0) {
118+
const embeddingValidationMessages = embeddingValidationResults.map(
119+
(validation) =>
120+
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
121+
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
122+
`actual quantization: ${validation.actualQuantization}. Error: ${validation.error}`
123+
);
124+
125+
throw new MongoDBError(
126+
ErrorCodes.AtlasVectorSearchInvalidQuery,
127+
formatUntrustedData("", ...embeddingValidationMessages)
128+
.map(({ text }) => text)
129+
.join("\n")
130+
);
131+
}
132+
}
133+
134+
public async findFieldsWithWrongEmbeddings(
107135
{
108136
database,
109137
collection,
@@ -239,21 +267,34 @@ export class VectorSearchEmbeddingsManager {
239267
return undefined;
240268
}
241269

242-
public async generateEmbeddings({
270+
public async assertVectorSearchIndexExists({
243271
database,
244272
collection,
245273
path,
246-
rawValues,
247-
embeddingParameters,
248-
inputType,
249274
}: {
250275
database: string;
251276
collection: string;
252277
path: string;
278+
}): Promise<void> {
279+
const embeddingInfoForCollection = await this.embeddingsForNamespace({ database, collection });
280+
const embeddingInfoForPath = embeddingInfoForCollection.find((definition) => definition.path === path);
281+
if (!embeddingInfoForPath) {
282+
throw new MongoDBError(
283+
ErrorCodes.AtlasVectorSearchIndexNotFound,
284+
`No Vector Search index found for path "${path}" in namespace "${database}.${collection}"`
285+
);
286+
}
287+
}
288+
289+
public async generateEmbeddings({
290+
rawValues,
291+
embeddingParameters,
292+
inputType,
293+
}: {
253294
rawValues: string[];
254295
embeddingParameters: SupportedEmbeddingParameters;
255296
inputType: EmbeddingParameters["inputType"];
256-
}): Promise<unknown[]> {
297+
}): Promise<unknown[][]> {
257298
const provider = await this.atlasSearchEnabledProvider();
258299
if (!provider) {
259300
throw new MongoDBError(
@@ -275,15 +316,6 @@ export class VectorSearchEmbeddingsManager {
275316
});
276317
}
277318

278-
const embeddingInfoForCollection = await this.embeddingsForNamespace({ database, collection });
279-
const embeddingInfoForPath = embeddingInfoForCollection.find((definition) => definition.path === path);
280-
if (!embeddingInfoForPath) {
281-
throw new MongoDBError(
282-
ErrorCodes.AtlasVectorSearchIndexNotFound,
283-
`No Vector Search index found for path "${path}" in namespace "${database}.${collection}"`
284-
);
285-
}
286-
287319
return await embeddingsProvider.embed(embeddingParameters.model, rawValues, {
288320
inputType,
289321
...embeddingParameters,

src/tools/mongodb/create/insertMany.ts

Lines changed: 117 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
33
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
44
import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js";
55
import { zEJSON } from "../../args.js";
6+
import { type Document } from "bson";
7+
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
8+
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
9+
10+
const zSupportedEmbeddingParametersWithInput = zSupportedEmbeddingParameters.extend({
11+
input: z
12+
.array(z.object({}).passthrough())
13+
.describe(
14+
"Array of objects with vector search index fields as keys (in dot notation) and the raw text values to generate embeddings for as values. The index of each object corresponds to the index of the document in the documents array."
15+
),
16+
});
617

718
export class InsertManyTool extends MongoDBToolBase {
819
public name = "insert-many";
@@ -12,46 +23,44 @@ export class InsertManyTool extends MongoDBToolBase {
1223
documents: z
1324
.array(zEJSON().describe("An individual MongoDB document"))
1425
.describe(
15-
"The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany()"
26+
"The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany()."
1627
),
28+
...(this.isFeatureEnabled("vectorSearch")
29+
? {
30+
embeddingParameters: zSupportedEmbeddingParametersWithInput
31+
.optional()
32+
.describe(
33+
"The embedding model and its parameters to use to generate embeddings for fields with vector search indexes. Note to LLM: If unsure which embedding model to use, ask the user before providing one."
34+
),
35+
}
36+
: {}),
1737
};
1838
public operationType: OperationType = "create";
1939

2040
protected async execute({
2141
database,
2242
collection,
2343
documents,
44+
embeddingParameters: providedEmbeddingParameters,
2445
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
2546
const provider = await this.ensureConnected();
2647

27-
const embeddingValidations = new Set(
28-
...(await Promise.all(
29-
documents.flatMap((document) =>
30-
this.session.vectorSearchEmbeddingsManager.findFieldsWithWrongEmbeddings(
31-
{ database, collection },
32-
document
33-
)
34-
)
35-
))
36-
);
48+
const embeddingParameters = this.isFeatureEnabled("vectorSearch")
49+
? (providedEmbeddingParameters as z.infer<typeof zSupportedEmbeddingParametersWithInput>)
50+
: undefined;
3751

38-
if (embeddingValidations.size > 0) {
39-
// tell the LLM what happened
40-
const embeddingValidationMessages = [...embeddingValidations].map(
41-
(validation) =>
42-
`- Field ${validation.path} is an embedding with ${validation.expectedNumDimensions} dimensions and ${validation.expectedQuantization}` +
43-
` quantization, and the provided value is not compatible. Actual dimensions: ${validation.actualNumDimensions}, ` +
44-
`actual quantization: ${validation.actualQuantization}. Error: ${validation.error}`
45-
);
46-
47-
return {
48-
content: formatUntrustedData(
49-
"There were errors when inserting documents. No document was inserted.",
50-
...embeddingValidationMessages
51-
),
52-
isError: true,
53-
};
54-
}
52+
// Process documents to replace raw string values with generated embeddings
53+
documents = await this.replaceRawValuesWithEmbeddingsIfNecessary({
54+
database,
55+
collection,
56+
documents,
57+
embeddingParameters,
58+
});
59+
60+
await this.session.vectorSearchEmbeddingsManager.assertFieldsHaveCorrectEmbeddings(
61+
{ database, collection },
62+
documents
63+
);
5564

5665
const result = await provider.insertMany(database, collection, documents);
5766
const content = formatUntrustedData(
@@ -63,4 +72,84 @@ export class InsertManyTool extends MongoDBToolBase {
6372
content,
6473
};
6574
}
75+
76+
private async replaceRawValuesWithEmbeddingsIfNecessary({
77+
database,
78+
collection,
79+
documents,
80+
embeddingParameters,
81+
}: {
82+
database: string;
83+
collection: string;
84+
documents: Document[];
85+
embeddingParameters?: z.infer<typeof zSupportedEmbeddingParametersWithInput>;
86+
}): Promise<Document[]> {
87+
// If no embedding parameters or no input specified, return documents as-is
88+
if (!embeddingParameters?.input || embeddingParameters.input.length === 0) {
89+
return documents;
90+
}
91+
92+
// Get vector search indexes for the collection
93+
const vectorIndexes = await this.session.vectorSearchEmbeddingsManager.embeddingsForNamespace({
94+
database,
95+
collection,
96+
});
97+
98+
// Ensure for inputted fields, the vector search index exists.
99+
for (const input of embeddingParameters.input) {
100+
for (const fieldPath of Object.keys(input)) {
101+
if (!vectorIndexes.some((index) => index.path === fieldPath)) {
102+
throw new MongoDBError(
103+
ErrorCodes.AtlasVectorSearchInvalidQuery,
104+
`Field '${fieldPath}' does not have a vector search index in collection ${database}.${collection}. Only fields with vector search indexes can have embeddings generated.`
105+
);
106+
}
107+
}
108+
}
109+
110+
// We make one call to generate embeddings for all documents at once to avoid making too many API calls.
111+
const flattenedEmbeddingsInput = embeddingParameters.input.flatMap((documentInput, index) =>
112+
Object.entries(documentInput).map(([fieldPath, rawTextValue]) => ({
113+
fieldPath,
114+
rawTextValue,
115+
documentIndex: index,
116+
}))
117+
);
118+
119+
const generatedEmbeddings = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
120+
rawValues: flattenedEmbeddingsInput.map(({ rawTextValue }) => rawTextValue) as string[],
121+
embeddingParameters,
122+
inputType: "document",
123+
});
124+
125+
const processedDocuments: Document[] = [...documents];
126+
127+
for (const [index, { fieldPath, documentIndex }] of flattenedEmbeddingsInput.entries()) {
128+
if (!processedDocuments[documentIndex]) {
129+
throw new MongoDBError(ErrorCodes.Unexpected, `Document at index ${documentIndex} does not exist.`);
130+
}
131+
// Ensure no nested fields are present in the field path.
132+
this.deleteFieldPath(processedDocuments[documentIndex], fieldPath);
133+
processedDocuments[documentIndex][fieldPath] = generatedEmbeddings[index];
134+
}
135+
136+
return processedDocuments;
137+
}
138+
139+
// Delete a specified field path from a document using dot notation.
140+
private deleteFieldPath(document: Record<string, unknown>, fieldPath: string): void {
141+
const parts = fieldPath.split(".");
142+
let current: Record<string, unknown> = document;
143+
for (let i = 0; i < parts.length; i++) {
144+
const part = parts[i];
145+
const key = part as keyof typeof current;
146+
if (!current[key]) {
147+
return;
148+
} else if (i === parts.length - 1) {
149+
delete current[key];
150+
} else {
151+
current = current[key] as Record<string, unknown>;
152+
}
153+
}
154+
}
66155
}

src/tools/mongodb/read/aggregate.ts

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,22 +276,37 @@ export class AggregateTool extends MongoDBToolBase {
276276
const embeddingParameters = vectorSearchStage.embeddingParameters;
277277
delete vectorSearchStage.embeddingParameters;
278278

279-
const [embeddings] = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
279+
await this.session.vectorSearchEmbeddingsManager.assertVectorSearchIndexExists({
280280
database,
281281
collection,
282282
path: vectorSearchStage.path,
283+
});
284+
285+
const [embeddings] = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
283286
rawValues: [vectorSearchStage.queryVector],
284287
embeddingParameters,
285288
inputType: "query",
286289
});
287290

291+
if (!embeddings) {
292+
throw new MongoDBError(
293+
ErrorCodes.AtlasVectorSearchInvalidQuery,
294+
"Failed to generate embeddings for the query vector."
295+
);
296+
}
297+
288298
// $vectorSearch.queryVector can be a BSON.Binary: that it's not either number or an array.
289299
// It's not exactly valid from the LLM perspective (they can't provide binaries).
290300
// That's why we overwrite the stage in an untyped way, as what we expose and what LLMs can use is different.
291-
vectorSearchStage.queryVector = embeddings as number[];
301+
vectorSearchStage.queryVector = embeddings as string | number[];
292302
}
293303
}
294304

305+
await this.session.vectorSearchEmbeddingsManager.assertFieldsHaveCorrectEmbeddings(
306+
{ database, collection },
307+
pipeline
308+
);
309+
295310
return pipeline;
296311
}
297312

0 commit comments

Comments
 (0)