diff --git a/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx b/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx
index 771921ba6de2..32389a0d2746 100644
--- a/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx
+++ b/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx
@@ -8,7 +8,8 @@ sidebar_class_name: node-only
Only available on Node.js.
:::
-Langchain supports MongoDB Atlas as a vector store.
+LangChain.js supports MongoDB Atlas as a vector store, and supports both standard similarity search and maximal marginal relevance search,
+which takes a combination of documents that are most similar to the inputs, then reranks and optimizes for diversity.
## Setup
@@ -69,3 +70,9 @@ import Ingestion from "@examples/indexes/vector_stores/mongodb_atlas_fromTexts.t
import Search from "@examples/indexes/vector_stores/mongodb_atlas_search.ts";
{Search}
+
+### Maximal marginal relevance
+
+import MMRExample from "@examples/indexes/vector_stores/mongodb_mmr.ts";
+
+{MMRExample}
\ No newline at end of file
diff --git a/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts b/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts
index 710fe626f4bf..8e3044fd51be 100755
--- a/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts
+++ b/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts
@@ -2,23 +2,21 @@ import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas";
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import { MongoClient } from "mongodb";
-export const run = async () => {
- const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
- const namespace = "langchain.test";
- const [dbName, collectionName] = namespace.split(".");
- const collection = client.db(dbName).collection(collectionName);
+const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
+const namespace = "langchain.test";
+const [dbName, collectionName] = namespace.split(".");
+const collection = client.db(dbName).collection(collectionName);
- await MongoDBAtlasVectorSearch.fromTexts(
- ["Hello world", "Bye bye", "What's this?"],
- [{ id: 2 }, { id: 1 }, { id: 3 }],
- new CohereEmbeddings(),
- {
- collection,
- indexName: "default", // The name of the Atlas search index. Defaults to "default"
- textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
- embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
- }
- );
+await MongoDBAtlasVectorSearch.fromTexts(
+ ["Hello world", "Bye bye", "What's this?"],
+ [{ id: 2 }, { id: 1 }, { id: 3 }],
+ new CohereEmbeddings(),
+ {
+ collection,
+ indexName: "default", // The name of the Atlas search index. Defaults to "default"
+ textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
+ embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
+ }
+);
- await client.close();
-};
+await client.close();
diff --git a/examples/src/indexes/vector_stores/mongodb_atlas_search.ts b/examples/src/indexes/vector_stores/mongodb_atlas_search.ts
index 5f071a1fe346..7dbb17bff18b 100755
--- a/examples/src/indexes/vector_stores/mongodb_atlas_search.ts
+++ b/examples/src/indexes/vector_stores/mongodb_atlas_search.ts
@@ -2,21 +2,19 @@ import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas";
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import { MongoClient } from "mongodb";
-export const run = async () => {
- const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
- const namespace = "langchain.test";
- const [dbName, collectionName] = namespace.split(".");
- const collection = client.db(dbName).collection(collectionName);
+const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
+const namespace = "langchain.test";
+const [dbName, collectionName] = namespace.split(".");
+const collection = client.db(dbName).collection(collectionName);
- const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
- collection,
- indexName: "default", // The name of the Atlas search index. Defaults to "default"
- textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
- embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
- });
+const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
+ collection,
+ indexName: "default", // The name of the Atlas search index. Defaults to "default"
+ textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
+ embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
+});
- const resultOne = await vectorStore.similaritySearch("Hello world", 1);
- console.log(resultOne);
+const resultOne = await vectorStore.similaritySearch("Hello world", 1);
+console.log(resultOne);
- await client.close();
-};
+await client.close();
diff --git a/examples/src/indexes/vector_stores/mongodb_mmr.ts b/examples/src/indexes/vector_stores/mongodb_mmr.ts
new file mode 100644
index 000000000000..496f152fcc14
--- /dev/null
+++ b/examples/src/indexes/vector_stores/mongodb_mmr.ts
@@ -0,0 +1,37 @@
+import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas";
+import { CohereEmbeddings } from "langchain/embeddings/cohere";
+import { MongoClient } from "mongodb";
+
+const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
+const namespace = "langchain.test";
+const [dbName, collectionName] = namespace.split(".");
+const collection = client.db(dbName).collection(collectionName);
+
+const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
+ collection,
+ indexName: "default", // The name of the Atlas search index. Defaults to "default"
+ textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
+ embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
+});
+
+const resultOne = await vectorStore.maxMarginalRelevanceSearch("Hello world", {
+ k: 4,
+ fetchK: 20, // The number of documents to return on initial fetch
+});
+console.log(resultOne);
+
+// Using MMR in a vector store retriever
+
+const retriever = await vectorStore.asRetriever({
+ searchType: "mmr",
+ searchKwargs: {
+ fetchK: 20,
+ lambda: 0.1,
+ },
+});
+
+const retrieverOutput = await retriever.getRelevantDocuments("Hello world");
+
+console.log(retrieverOutput);
+
+await client.close();
diff --git a/langchain/src/retrievers/hyde.ts b/langchain/src/retrievers/hyde.ts
index 102649baea7b..86aa81852482 100644
--- a/langchain/src/retrievers/hyde.ts
+++ b/langchain/src/retrievers/hyde.ts
@@ -20,11 +20,11 @@ export type PromptKey =
| "trec-news"
| "mr-tydi";
-export interface HydeRetrieverOptions
- extends VectorStoreRetrieverInput {
- llm: BaseLanguageModel;
- promptTemplate?: BasePromptTemplate | PromptKey;
-}
+export type HydeRetrieverOptions =
+ VectorStoreRetrieverInput & {
+ llm: BaseLanguageModel;
+ promptTemplate?: BasePromptTemplate | PromptKey;
+ };
export class HydeRetriever<
V extends VectorStore = VectorStore
@@ -86,17 +86,17 @@ export function getPromptTemplateFromKey(key: PromptKey): BasePromptTemplate {
switch (key) {
case "websearch":
- template = `Please write a passage to answer the question
+ template = `Please write a passage to answer the question
Question: {question}
Passage:`;
break;
case "scifact":
- template = `Please write a scientific paper passage to support/refute the claim
+ template = `Please write a scientific paper passage to support/refute the claim
Claim: {question}
Passage:`;
break;
case "arguana":
- template = `Please write a counter argument for the passage
+ template = `Please write a counter argument for the passage
Passage: {question}
Counter Argument:`;
break;
diff --git a/langchain/src/vectorstores/base.ts b/langchain/src/vectorstores/base.ts
index 7605b6e388e2..f27dc7e85734 100644
--- a/langchain/src/vectorstores/base.ts
+++ b/langchain/src/vectorstores/base.ts
@@ -12,17 +12,33 @@ type AddDocumentOptions = Record;
export type MaxMarginalRelevanceSearchOptions = {
k: number;
- fetchK: number;
- lambda: number;
+ fetchK?: number;
+ lambda?: number;
filter?: FilterType;
};
-export interface VectorStoreRetrieverInput
- extends BaseRetrieverInput {
- vectorStore: V;
- k?: number;
- filter?: V["FilterType"];
-}
+export type VectorStoreRetrieverMMRSearchKwargs = {
+ fetchK?: number;
+ lambda?: number;
+};
+
+export type VectorStoreRetrieverInput =
+ BaseRetrieverInput &
+ (
+ | {
+ vectorStore: V;
+ k?: number;
+ filter?: V["FilterType"];
+ searchType?: "similarity";
+ }
+ | {
+ vectorStore: V;
+ k?: number;
+ filter?: V["FilterType"];
+ searchType: "mmr";
+ searchKwargs?: VectorStoreRetrieverMMRSearchKwargs;
+ }
+ );
export class VectorStoreRetriever<
V extends VectorStore = VectorStore
@@ -35,6 +51,10 @@ export class VectorStoreRetriever<
k = 4;
+ searchType = "similarity";
+
+ searchKwargs?: VectorStoreRetrieverMMRSearchKwargs;
+
filter?: V["FilterType"];
_vectorstoreType(): string {
@@ -45,13 +65,33 @@ export class VectorStoreRetriever<
super(fields);
this.vectorStore = fields.vectorStore;
this.k = fields.k ?? this.k;
+ this.searchType = fields.searchType ?? this.searchType;
this.filter = fields.filter;
+ if (fields.searchType === "mmr") {
+ this.searchKwargs = fields.searchKwargs;
+ }
}
async _getRelevantDocuments(
query: string,
runManager?: CallbackManagerForRetrieverRun
): Promise {
+ if (this.searchType === "mmr") {
+ if (typeof this.vectorStore.maxMarginalRelevanceSearch !== "function") {
+ throw new Error(
+ `The vector store backing this retriever, ${this._vectorstoreType()} does not support max marginal relevance search.`
+ );
+ }
+ return this.vectorStore.maxMarginalRelevanceSearch(
+ query,
+ {
+ k: this.k,
+ filter: this.filter,
+ ...this.searchKwargs,
+ },
+ runManager?.getChild("vectorstore")
+ );
+ }
return this.vectorStore.similaritySearch(
query,
this.k,
@@ -196,7 +236,7 @@ export abstract class VectorStore extends Serializable {
callbacks,
});
} else {
- return new VectorStoreRetriever({
+ const params = {
vectorStore: this,
k: kOrFields?.k,
filter: kOrFields?.filter,
@@ -204,7 +244,15 @@ export abstract class VectorStore extends Serializable {
metadata: kOrFields?.metadata,
verbose: kOrFields?.verbose,
callbacks: kOrFields?.callbacks,
- });
+ searchType: kOrFields?.searchType,
+ };
+ if (kOrFields?.searchType === "mmr") {
+ return new VectorStoreRetriever({
+ ...params,
+ searchKwargs: kOrFields.searchKwargs,
+ });
+ }
+ return new VectorStoreRetriever({ ...params });
}
}
}
diff --git a/langchain/src/vectorstores/mongodb_atlas.ts b/langchain/src/vectorstores/mongodb_atlas.ts
index 9c6dc187b373..5e6324ce8d52 100755
--- a/langchain/src/vectorstores/mongodb_atlas.ts
+++ b/langchain/src/vectorstores/mongodb_atlas.ts
@@ -140,6 +140,8 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
query: string,
options: MaxMarginalRelevanceSearchOptions
): Promise {
+ const { k, fetchK = 20, lambda = 0.5, filter } = options;
+
const queryEmbedding = await this.embeddings.embedQuery(query);
// preserve the original value of includeEmbeddings
@@ -147,13 +149,13 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
// update filter to include embeddings, as they will be used in MMR
const includeEmbeddingsFilter = {
- ...options.filter,
+ ...filter,
includeEmbeddings: true,
};
const resultDocs = await this.similaritySearchVectorWithScore(
queryEmbedding,
- options.fetchK,
+ fetchK,
includeEmbeddingsFilter
);
@@ -164,8 +166,8 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
const mmrIndexes = maximalMarginalRelevance(
queryEmbedding,
embeddingList,
- options.lambda,
- options.k
+ lambda,
+ k
);
return mmrIndexes.map((idx) => {
diff --git a/langchain/src/vectorstores/tests/mongo.int.test.ts b/langchain/src/vectorstores/tests/mongo.int.test.ts
deleted file mode 100644
index 8a853a870893..000000000000
--- a/langchain/src/vectorstores/tests/mongo.int.test.ts
+++ /dev/null
@@ -1,101 +0,0 @@
-/* eslint-disable no-process-env */
-/* eslint-disable no-promise-executor-return */
-
-import { test, expect } from "@jest/globals";
-import { MongoClient } from "mongodb";
-import { CohereEmbeddings } from "../../embeddings/cohere.js";
-import { MongoVectorStore, MongoVectorStoreQueryExtension } from "../mongo.js";
-
-import { Document } from "../../document.js";
-
-/**
- * The following json can be used to create an index in atlas for cohere embeddings:
-{
- "mappings": {
- "fields": {
- "embedding": [
- {
- "dimensions": 1024,
- "similarity": "euclidean",
- "type": "knnVector"
- }
- ]
- }
- }
-}
- */
-
-test.skip("MongoVectorStore with external ids", async () => {
- expect(process.env.MONGO_URI).toBeDefined();
-
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
- const client = new MongoClient(process.env.MONGO_URI!);
-
- try {
- const collection = client.db("langchain").collection("test");
-
- const vectorStore = new MongoVectorStore(new CohereEmbeddings(), {
- client,
- collection,
- // indexName: "default", // make sure that this matches the index name in atlas if not using "default"
- });
-
- expect(vectorStore).toBeDefined();
-
- // check if the database is empty
- const count = await collection.countDocuments();
-
- const justInserted = count === 0;
- if (justInserted) {
- await vectorStore.addDocuments([
- { pageContent: "Dogs are tough.", metadata: { a: 1 } },
- { pageContent: "Cats have fluff.", metadata: { b: 1 } },
- { pageContent: "What is a sandwich?", metadata: { c: 1 } },
- { pageContent: "That fence is purple.", metadata: { d: 1, e: 2 } },
- ]);
- }
-
- // This test is awkward because the index in atlas takes time to index new documents
- // This means from a fresh insert the query will return nothing
- let triesLeft = 4;
-
- let results: Document[] = [];
- while (triesLeft > 0) {
- results = await vectorStore.similaritySearch("Sandwich", 1);
-
- if (justInserted && results.length === 0 && triesLeft > 0) {
- // wait and try again in hopes that the indexing has finished
- await new Promise((resolve) => setTimeout(resolve, 3000));
- }
-
- triesLeft -= 1;
- }
-
- expect(results).toEqual([
- { pageContent: "What is a sandwich?", metadata: { c: 1 } },
- ]);
-
- // we can filter the search with custom pipeline stages
- const filter: MongoVectorStoreQueryExtension = {
- postQueryPipelineSteps: [
- {
- $match: {
- "metadata.e": { $exists: true },
- },
- },
- ],
- };
-
- const filteredResults = await vectorStore.similaritySearch(
- "Sandwich",
- 4,
- filter
- );
-
- expect(filteredResults).toEqual([
- { pageContent: "That fence is purple.", metadata: { d: 1, e: 2 } },
- ]);
- } finally {
- await client.close();
- }
-});
diff --git a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts b/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts
index f395b133940f..bb2cc9a38ad3 100755
--- a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts
+++ b/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts
@@ -131,6 +131,33 @@ test.skip("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () =
const actual = output.map((doc) => doc.pageContent);
const expected = ["foo", "foy", "foo"];
expect(actual).toEqual(expected);
+ // A retriever created without options uses plain similarity search, so results keep similarity order.
+ const standardRetriever = await vectorStore.asRetriever();
+
+ const standardRetrieverOutput =
+ await standardRetriever.getRelevantDocuments("foo");
+ expect(standardRetrieverOutput).toHaveLength(texts.length);
+
+ const standardRetrieverActual = standardRetrieverOutput.map(
+ (doc) => doc.pageContent
+ );
+ const standardRetrieverExpected = ["foo", "foo", "foy"];
+ expect(standardRetrieverActual).toEqual(standardRetrieverExpected);
+ // With searchType "mmr" the retriever should return the same diversity-reranked order as the direct MMR search.
+ const retriever = await vectorStore.asRetriever({
+ searchType: "mmr",
+ searchKwargs: {
+ fetchK: 20,
+ lambda: 0.1,
+ },
+ });
+
+ const retrieverOutput = await retriever.getRelevantDocuments("foo");
+ expect(retrieverOutput).toHaveLength(texts.length);
+
+ const retrieverActual = retrieverOutput.map((doc) => doc.pageContent);
+ const retrieverExpected = ["foo", "foy", "foo"];
+ expect(retrieverActual).toEqual(retrieverExpected);
} finally {
await client.close();
}