Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 40 additions & 32 deletions libs/langchain-mongodb/src/tests/vectorstores.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import { beforeAll, expect, jest, test } from "@jest/globals";
import { Collection, MongoClient } from "mongodb";
import { setTimeout } from "timers/promises";
import { OpenAIEmbeddings } from "@langchain/openai";
import { OpenAIEmbeddings, AzureOpenAIEmbeddings } from "@langchain/openai";
import { Document } from "@langchain/core/documents";
// eslint-disable-next-line import/no-extraneous-dependencies
import { Document as BSONDocument } from "bson";

import { EmbeddingsInterface } from "@langchain/core/embeddings";
import { MongoDBAtlasVectorSearch } from "../vectorstores.js";
import { isUsingLocalAtlas, uri, waitForIndexToBeQueryable } from "./utils.js";

Expand Down Expand Up @@ -102,8 +103,18 @@ class PatchedVectorStore extends MongoDBAtlasVectorSearch {
}
}

function getEmbeddings() {
if (process.env.AZURE_OPENAI_API_KEY) {
return new AzureOpenAIEmbeddings({
model: "text-embedding-3-small",
azureOpenAIApiDeploymentName: "openai/deployments/text-embedding-3-small",
});
}
return new OpenAIEmbeddings();
}

test("MongoDBAtlasVectorSearch with external ids", async () => {
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
collection,
});

Expand Down Expand Up @@ -166,7 +177,7 @@ test("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () => {
const vectorStore = await PatchedVectorStore.fromTexts(
texts,
{},
new OpenAIEmbeddings(),
getEmbeddings(),
{ collection, indexName: "default" }
);

Expand Down Expand Up @@ -215,7 +226,7 @@ test("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () => {
});

test("MongoDBAtlasVectorSearch upsert", async () => {
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
collection,
});

Expand Down Expand Up @@ -253,15 +264,15 @@ test("MongoDBAtlasVectorSearch upsert", async () => {

describe("MongoDBAtlasVectorSearch Constructor", () => {
test("initializes with minimal configuration", () => {
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
collection,
});
expect(vectorStore).toBeDefined();
});

test("initializes with custom index name", () => {
const customIndexName = "custom_index";
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
collection,
indexName: customIndexName,
});
Expand All @@ -271,7 +282,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
});

test("initializes with custom field names", () => {
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
collection,
textKey: "content",
embeddingKey: "vector",
Expand All @@ -287,7 +298,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
});

test("initializes AsyncCaller with custom parameters", () => {
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
collection,
maxConcurrency: 5,
maxRetries: 3,
Expand All @@ -304,7 +315,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
});

describe("addVectors method", () => {
let embeddings: OpenAIEmbeddings;
let embeddings: EmbeddingsInterface;
let vectorStore: PatchedVectorStore;
let vectors: number[][];
const documents = [
Expand All @@ -313,8 +324,8 @@ describe("addVectors method", () => {
];

beforeEach(async () => {
embeddings = new OpenAIEmbeddings();
vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
embeddings = getEmbeddings();
vectorStore = new PatchedVectorStore(getEmbeddings(), {
collection,
});
vectors = await embeddings.embedDocuments(["test 1", "test 2"]);
Expand Down Expand Up @@ -388,14 +399,14 @@ describe("addVectors method", () => {
});

describe("addDocuments method", () => {
let embeddings: OpenAIEmbeddings;
let embeddings: EmbeddingsInterface;
let vectorStore: PatchedVectorStore;
const documents = [
new Document({ pageContent: "test 1" }),
new Document({ pageContent: "test 2" }),
];
beforeEach(async () => {
embeddings = new OpenAIEmbeddings();
embeddings = getEmbeddings();
vectorStore = new PatchedVectorStore(embeddings, {
collection,
});
Expand Down Expand Up @@ -474,10 +485,10 @@ describe("addDocuments method", () => {
});

describe("similaritySearchVectorWithScore method", () => {
let embeddings: OpenAIEmbeddings;
let embeddings: EmbeddingsInterface;
let vectorStore: PatchedVectorStore;
beforeEach(async () => {
embeddings = new OpenAIEmbeddings();
embeddings = getEmbeddings();
vectorStore = new PatchedVectorStore(embeddings, {
collection,
});
Expand Down Expand Up @@ -717,7 +728,7 @@ describe("delete method", () => {
});

test("removes documents by ids", async () => {
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
collection,
});

Expand Down Expand Up @@ -767,7 +778,7 @@ describe("delete method", () => {
});

test("ignores non-existent ids", async () => {
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
collection,
});

Expand All @@ -783,10 +794,10 @@ describe("delete method", () => {

describe("Static Methods", () => {
describe("fromTexts", () => {
let embeddings: OpenAIEmbeddings;
let embeddings: EmbeddingsInterface;
const texts = ["text1", "text2", "text3"];
beforeEach(() => {
embeddings = new OpenAIEmbeddings();
embeddings = getEmbeddings();
});

test("populates a vector store from strings with a metadata object", async () => {
Expand Down Expand Up @@ -838,7 +849,7 @@ describe("Static Methods", () => {
];
const store = await MongoDBAtlasVectorSearch.fromDocuments(
documents,
new OpenAIEmbeddings(),
getEmbeddings(),
{ collection }
);
expect(store).toBeInstanceOf(MongoDBAtlasVectorSearch);
Expand All @@ -850,11 +861,9 @@ describe("Static Methods", () => {
new Document({ pageContent: "doc2", metadata: { source: "source2" } }),
];

await MongoDBAtlasVectorSearch.fromDocuments(
documents,
new OpenAIEmbeddings(),
{ collection }
);
await MongoDBAtlasVectorSearch.fromDocuments(documents, getEmbeddings(), {
collection,
});

const results = await collection
.find({}, { projection: { _id: 0, text: 1, embedding: 1 } })
Expand All @@ -873,11 +882,10 @@ describe("Static Methods", () => {
new Document({ pageContent: "doc2", metadata: { source: "source2" } }),
];

await MongoDBAtlasVectorSearch.fromDocuments(
documents,
new OpenAIEmbeddings(),
{ collection, ids: ["custom1", "custom2"] }
);
await MongoDBAtlasVectorSearch.fromDocuments(documents, getEmbeddings(), {
collection,
ids: ["custom1", "custom2"],
});

const results = await collection
.find({}, { projection: { _id: 1, text: 1 } })
Expand All @@ -896,7 +904,7 @@ describe("Static Methods", () => {
new Document({ pageContent: "doc1" }),
new Document({ pageContent: "doc2" }),
],
new OpenAIEmbeddings(),
getEmbeddings(),
{ collection, ids: ["id1", "id2"] }
);

Expand All @@ -905,7 +913,7 @@ describe("Static Methods", () => {
new Document({ pageContent: "updated 1" }),
new Document({ pageContent: "updated 2" }),
],
new OpenAIEmbeddings(),
getEmbeddings(),
{ collection, ids: ["id1", "id2"] }
);

Expand Down
Loading