diff --git a/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx b/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx index 771921ba6de2..32389a0d2746 100644 --- a/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx +++ b/docs/extras/modules/data_connection/vectorstores/integrations/mongodb_atlas.mdx @@ -8,7 +8,8 @@ sidebar_class_name: node-only Only available on Node.js. ::: -Langchain supports MongoDB Atlas as a vector store. +LangChain.js supports MongoDB Atlas as a vector store, and supports both standard similarity search and maximal marginal relevance search, +which takes a combination of documents are most similar to the inputs, then reranks and optimizes for diversity. ## Setup @@ -69,3 +70,9 @@ import Ingestion from "@examples/indexes/vector_stores/mongodb_atlas_fromTexts.t import Search from "@examples/indexes/vector_stores/mongodb_atlas_search.ts"; {Search} + +### Maximal marginal relevance + +import MMRExample from "@examples/indexes/vector_stores/mongodb_mmr.ts"; + +{MMRExample} \ No newline at end of file diff --git a/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts b/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts index 710fe626f4bf..8e3044fd51be 100755 --- a/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts +++ b/examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts @@ -2,23 +2,21 @@ import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas"; import { CohereEmbeddings } from "langchain/embeddings/cohere"; import { MongoClient } from "mongodb"; -export const run = async () => { - const client = new MongoClient(process.env.MONGODB_ATLAS_URI || ""); - const namespace = "langchain.test"; - const [dbName, collectionName] = namespace.split("."); - const collection = client.db(dbName).collection(collectionName); +const client = new MongoClient(process.env.MONGODB_ATLAS_URI || ""); +const namespace = "langchain.test"; +const [dbName, collectionName] = namespace.split("."); +const collection = client.db(dbName).collection(collectionName); - await MongoDBAtlasVectorSearch.fromTexts( - ["Hello world", "Bye bye", "What's this?"], - [{ id: 2 }, { id: 1 }, { id: 3 }], - new CohereEmbeddings(), - { - collection, - indexName: "default", // The name of the Atlas search index. Defaults to "default" - textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" - embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" - } - ); +await MongoDBAtlasVectorSearch.fromTexts( + ["Hello world", "Bye bye", "What's this?"], + [{ id: 2 }, { id: 1 }, { id: 3 }], + new CohereEmbeddings(), + { + collection, + indexName: "default", // The name of the Atlas search index. Defaults to "default" + textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" + embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" + } +); - await client.close(); -}; +await client.close(); diff --git a/examples/src/indexes/vector_stores/mongodb_atlas_search.ts b/examples/src/indexes/vector_stores/mongodb_atlas_search.ts index 5f071a1fe346..7dbb17bff18b 100755 --- a/examples/src/indexes/vector_stores/mongodb_atlas_search.ts +++ b/examples/src/indexes/vector_stores/mongodb_atlas_search.ts @@ -2,21 +2,19 @@ import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas"; import { CohereEmbeddings } from "langchain/embeddings/cohere"; import { MongoClient } from "mongodb"; -export const run = async () => { - const client = new MongoClient(process.env.MONGODB_ATLAS_URI || ""); - const namespace = "langchain.test"; - const [dbName, collectionName] = namespace.split("."); - const collection = client.db(dbName).collection(collectionName); +const client = new MongoClient(process.env.MONGODB_ATLAS_URI || ""); +const namespace = "langchain.test"; +const [dbName, collectionName] = namespace.split("."); +const collection = client.db(dbName).collection(collectionName); - const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), { - collection, - indexName: "default", // The name of the Atlas search index. Defaults to "default" - textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" - embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" - }); +const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), { + collection, + indexName: "default", // The name of the Atlas search index. Defaults to "default" + textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" + embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" +}); - const resultOne = await vectorStore.similaritySearch("Hello world", 1); - console.log(resultOne); +const resultOne = await vectorStore.similaritySearch("Hello world", 1); +console.log(resultOne); - await client.close(); -}; +await client.close(); diff --git a/examples/src/indexes/vector_stores/mongodb_mmr.ts b/examples/src/indexes/vector_stores/mongodb_mmr.ts new file mode 100644 index 000000000000..496f152fcc14 --- /dev/null +++ b/examples/src/indexes/vector_stores/mongodb_mmr.ts @@ -0,0 +1,37 @@ +import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas"; +import { CohereEmbeddings } from "langchain/embeddings/cohere"; +import { MongoClient } from "mongodb"; + +const client = new MongoClient(process.env.MONGODB_ATLAS_URI || ""); +const namespace = "langchain.test"; +const [dbName, collectionName] = namespace.split("."); +const collection = client.db(dbName).collection(collectionName); + +const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), { + collection, + indexName: "default", // The name of the Atlas search index. Defaults to "default" + textKey: "text", // The name of the collection field containing the raw content. Defaults to "text" + embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding" +}); + +const resultOne = await vectorStore.maxMarginalRelevanceSearch("Hello world", { + k: 4, + fetchK: 20, // The number of documents to return on initial fetch +}); +console.log(resultOne); + +// Using MMR in a vector store retriever + +const retriever = await vectorStore.asRetriever({ + searchType: "mmr", + searchKwargs: { + fetchK: 20, + lambda: 0.1, + }, +}); + +const retrieverOutput = await retriever.getRelevantDocuments("Hello world"); + +console.log(retrieverOutput); + +await client.close(); diff --git a/langchain/src/retrievers/hyde.ts b/langchain/src/retrievers/hyde.ts index 102649baea7b..86aa81852482 100644 --- a/langchain/src/retrievers/hyde.ts +++ b/langchain/src/retrievers/hyde.ts @@ -20,11 +20,11 @@ export type PromptKey = | "trec-news" | "mr-tydi"; -export interface HydeRetrieverOptions - extends VectorStoreRetrieverInput { - llm: BaseLanguageModel; - promptTemplate?: BasePromptTemplate | PromptKey; -} +export type HydeRetrieverOptions = + VectorStoreRetrieverInput & { + llm: BaseLanguageModel; + promptTemplate?: BasePromptTemplate | PromptKey; + }; export class HydeRetriever< V extends VectorStore = VectorStore @@ -86,17 +86,17 @@ export function getPromptTemplateFromKey(key: PromptKey): BasePromptTemplate { switch (key) { case "websearch": - template = `Please write a passage to answer the question + template = `Please write a passage to answer the question Question: {question} Passage:`; break; case "scifact": - template = `Please write a scientific paper passage to support/refute the claim + template = `Please write a scientific paper passage to support/refute the claim Claim: {question} Passage:`; break; case "arguana": - template = `Please write a counter argument for the passage + template = `Please write a counter argument for the passage Passage: {question} Counter Argument:`; break; diff --git a/langchain/src/vectorstores/base.ts b/langchain/src/vectorstores/base.ts index 7605b6e388e2..f27dc7e85734 100644 --- a/langchain/src/vectorstores/base.ts +++ b/langchain/src/vectorstores/base.ts @@ -12,17 +12,33 @@ type AddDocumentOptions = Record; export type MaxMarginalRelevanceSearchOptions = { k: number; - fetchK: number; - lambda: number; + fetchK?: number; + lambda?: number; filter?: FilterType; }; -export interface VectorStoreRetrieverInput - extends BaseRetrieverInput { - vectorStore: V; - k?: number; - filter?: V["FilterType"]; -} +export type VectorStoreRetrieverMMRSearchKwargs = { + fetchK?: number; + lambda?: number; +}; + +export type VectorStoreRetrieverInput = + BaseRetrieverInput & + ( + | { + vectorStore: V; + k?: number; + filter?: V["FilterType"]; + searchType?: "similarity"; + } + | { + vectorStore: V; + k?: number; + filter?: V["FilterType"]; + searchType: "mmr"; + searchKwargs?: VectorStoreRetrieverMMRSearchKwargs; + } + ); export class VectorStoreRetriever< V extends VectorStore = VectorStore @@ -35,6 +51,10 @@ export class VectorStoreRetriever< k = 4; + searchType = "similarity"; + + searchKwargs?: VectorStoreRetrieverMMRSearchKwargs; + filter?: V["FilterType"]; _vectorstoreType(): string { @@ -45,13 +65,33 @@ export class VectorStoreRetriever< super(fields); this.vectorStore = fields.vectorStore; this.k = fields.k ?? this.k; + this.searchType = fields.searchType ?? this.searchType; this.filter = fields.filter; + if (fields.searchType === "mmr") { + this.searchKwargs = fields.searchKwargs; + } } async _getRelevantDocuments( query: string, runManager?: CallbackManagerForRetrieverRun ): Promise { + if (this.searchType === "mmr") { + if (typeof this.vectorStore.maxMarginalRelevanceSearch !== "function") { + throw new Error( + `The vector store backing this retriever, ${this._vectorstoreType()} does not support max marginal relevance search.` + ); + } + return this.vectorStore.maxMarginalRelevanceSearch( + query, + { + k: this.k, + filter: this.filter, + ...this.searchKwargs, + }, + runManager?.getChild("vectorstore") + ); + } return this.vectorStore.similaritySearch( query, this.k, @@ -196,7 +236,7 @@ export abstract class VectorStore extends Serializable { callbacks, }); } else { - return new VectorStoreRetriever({ + const params = { vectorStore: this, k: kOrFields?.k, filter: kOrFields?.filter, @@ -204,7 +244,15 @@ export abstract class VectorStore extends Serializable { metadata: kOrFields?.metadata, verbose: kOrFields?.verbose, callbacks: kOrFields?.callbacks, - }); + searchType: kOrFields?.searchType, + }; + if (kOrFields?.searchType === "mmr") { + return new VectorStoreRetriever({ + ...params, + searchKwargs: kOrFields.searchKwargs, + }); + } + return new VectorStoreRetriever({ ...params }); } } } diff --git a/langchain/src/vectorstores/mongodb_atlas.ts b/langchain/src/vectorstores/mongodb_atlas.ts index 9c6dc187b373..5e6324ce8d52 100755 --- a/langchain/src/vectorstores/mongodb_atlas.ts +++ b/langchain/src/vectorstores/mongodb_atlas.ts @@ -140,6 +140,8 @@ export class MongoDBAtlasVectorSearch extends VectorStore { query: string, options: MaxMarginalRelevanceSearchOptions ): Promise { + const { k, fetchK = 20, lambda = 0.5, filter } = options; + const queryEmbedding = await this.embeddings.embedQuery(query); // preserve the original value of includeEmbeddings @@ -147,13 +149,13 @@ export class MongoDBAtlasVectorSearch extends VectorStore { // update filter to include embeddings, as they will be used in MMR const includeEmbeddingsFilter = { - ...options.filter, + ...filter, includeEmbeddings: true, }; const resultDocs = await this.similaritySearchVectorWithScore( queryEmbedding, - options.fetchK, + fetchK, includeEmbeddingsFilter ); @@ -164,8 +166,8 @@ export class MongoDBAtlasVectorSearch extends VectorStore { const mmrIndexes = maximalMarginalRelevance( queryEmbedding, embeddingList, - options.lambda, - options.k + lambda, + k ); return mmrIndexes.map((idx) => { diff --git a/langchain/src/vectorstores/tests/mongo.int.test.ts b/langchain/src/vectorstores/tests/mongo.int.test.ts deleted file mode 100644 index 8a853a870893..000000000000 --- a/langchain/src/vectorstores/tests/mongo.int.test.ts +++ /dev/null @@ -1,101 +0,0 @@ -/* eslint-disable no-process-env */ -/* eslint-disable no-promise-executor-return */ - -import { test, expect } from "@jest/globals"; -import { MongoClient } from "mongodb"; -import { CohereEmbeddings } from "../../embeddings/cohere.js"; -import { MongoVectorStore, MongoVectorStoreQueryExtension } from "../mongo.js"; - -import { Document } from "../../document.js"; - -/** - * The following json can be used to create an index in atlas for cohere embeddings: -{ - "mappings": { - "fields": { - "embedding": [ - { - "dimensions": 1024, - "similarity": "euclidean", - "type": "knnVector" - } - ] - } - } -} - */ - -test.skip("MongoVectorStore with external ids", async () => { - expect(process.env.MONGO_URI).toBeDefined(); - - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const client = new MongoClient(process.env.MONGO_URI!); - - try { - const collection = client.db("langchain").collection("test"); - - const vectorStore = new MongoVectorStore(new CohereEmbeddings(), { - client, - collection, - // indexName: "default", // make sure that this matches the index name in atlas if not using "default" - }); - - expect(vectorStore).toBeDefined(); - - // check if the database is empty - const count = await collection.countDocuments(); - - const justInserted = count === 0; - if (justInserted) { - await vectorStore.addDocuments([ - { pageContent: "Dogs are tough.", metadata: { a: 1 } }, - { pageContent: "Cats have fluff.", metadata: { b: 1 } }, - { pageContent: "What is a sandwich?", metadata: { c: 1 } }, - { pageContent: "That fence is purple.", metadata: { d: 1, e: 2 } }, - ]); - } - - // This test is awkward because the index in atlas takes time to index new documents - // This means from a fresh insert the query will return nothing - let triesLeft = 4; - - let results: Document[] = []; - while (triesLeft > 0) { - results = await vectorStore.similaritySearch("Sandwich", 1); - - if (justInserted && results.length === 0 && triesLeft > 0) { - // wait and try again in hopes that the indexing has finished - await new Promise((resolve) => setTimeout(resolve, 3000)); - } - - triesLeft -= 1; - } - - expect(results).toEqual([ - { pageContent: "What is a sandwich?", metadata: { c: 1 } }, - ]); - - // we can filter the search with custom pipeline stages - const filter: MongoVectorStoreQueryExtension = { - postQueryPipelineSteps: [ - { - $match: { - "metadata.e": { $exists: true }, - }, - }, - ], - }; - - const filteredResults = await vectorStore.similaritySearch( - "Sandwich", - 4, - filter - ); - - expect(filteredResults).toEqual([ - { pageContent: "That fence is purple.", metadata: { d: 1, e: 2 } }, - ]); - } finally { - await client.close(); - } -}); diff --git a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts b/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts index f395b133940f..bb2cc9a38ad3 100755 --- a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts +++ b/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts @@ -131,6 +131,33 @@ test.skip("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () = const actual = output.map((doc) => doc.pageContent); const expected = ["foo", "foy", "foo"]; expect(actual).toEqual(expected); + + const standardRetriever = await vectorStore.asRetriever(); + + const standardRetrieverOutput = + await standardRetriever.getRelevantDocuments("foo"); + expect(output).toHaveLength(texts.length); + + const standardRetrieverActual = standardRetrieverOutput.map( + (doc) => doc.pageContent + ); + const standardRetrieverExpected = ["foo", "foo", "foy"]; + expect(standardRetrieverActual).toEqual(standardRetrieverExpected); + + const retriever = await vectorStore.asRetriever({ + searchType: "mmr", + searchKwargs: { + fetchK: 20, + lambda: 0.1, + }, + }); + + const retrieverOutput = await retriever.getRelevantDocuments("foo"); + expect(output).toHaveLength(texts.length); + + const retrieverActual = retrieverOutput.map((doc) => doc.pageContent); + const retrieverExpected = ["foo", "foy", "foo"]; + expect(retrieverActual).toEqual(retrieverExpected); } finally { await client.close(); }