Skip to content

Commit

Permalink
Adds search type to vectorstore retrievers for MMR, adds docs (#2170)
Browse files Browse the repository at this point in the history
* Refactored MongoDBAtlasVectorSearch to use abstract similaritySearchVectorWithScore for MMR

Co-authored-by: Simone DM <simone.dellamarca@mongodb.com>

* Fixed formatting

* Fixed formatting

* Fixed formatting

* Added optional maxMarginalRelevanceSearch method to base vectorstore

* Change max marginal relevance args

* Docs

* Adds search type to vectorstore retrievers, docs

* Update docs

* Docs

* Add back in mongo

* Format

* Fix bug

* Fix unit test

* Fix bug

---------

Co-authored-by: archie-swif <artem.ryabokon@gmail.com>
Co-authored-by: Simone DM <simone.dellamarca@mongodb.com>
  • Loading branch information
3 people authored Aug 5, 2023
1 parent 96c9fe5 commit ac0e6c6
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ sidebar_class_name: node-only
Only available on Node.js.
:::

Langchain supports MongoDB Atlas as a vector store.
LangChain.js supports MongoDB Atlas as a vector store, and supports both standard similarity search and maximal marginal relevance search,
which first fetches the documents that are most similar to the input, then reranks them to optimize for diversity.

## Setup

Expand Down Expand Up @@ -69,3 +70,9 @@ import Ingestion from "@examples/indexes/vector_stores/mongodb_atlas_fromTexts.t
import Search from "@examples/indexes/vector_stores/mongodb_atlas_search.ts";

<CodeBlock language="typescript">{Search}</CodeBlock>

### Maximal marginal relevance

import MMRExample from "@examples/indexes/vector_stores/mongodb_mmr.ts";

<CodeBlock language="typescript">{MMRExample}</CodeBlock>
34 changes: 16 additions & 18 deletions examples/src/indexes/vector_stores/mongodb_atlas_fromTexts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@ import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas";
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import { MongoClient } from "mongodb";

export const run = async () => {
const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);
const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

await MongoDBAtlasVectorSearch.fromTexts(
["Hello world", "Bye bye", "What's this?"],
[{ id: 2 }, { id: 1 }, { id: 3 }],
new CohereEmbeddings(),
{
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
}
);
await MongoDBAtlasVectorSearch.fromTexts(
["Hello world", "Bye bye", "What's this?"],
[{ id: 2 }, { id: 1 }, { id: 3 }],
new CohereEmbeddings(),
{
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
}
);

await client.close();
};
await client.close();
28 changes: 13 additions & 15 deletions examples/src/indexes/vector_stores/mongodb_atlas_search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas";
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import { MongoClient } from "mongodb";

export const run = async () => {
const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);
const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
});
const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
});

const resultOne = await vectorStore.similaritySearch("Hello world", 1);
console.log(resultOne);
const resultOne = await vectorStore.similaritySearch("Hello world", 1);
console.log(resultOne);

await client.close();
};
await client.close();
37 changes: 37 additions & 0 deletions examples/src/indexes/vector_stores/mongodb_mmr.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { MongoDBAtlasVectorSearch } from "langchain/vectorstores/mongodb_atlas";
import { CohereEmbeddings } from "langchain/embeddings/cohere";
import { MongoClient } from "mongodb";

// Example script: maximal marginal relevance (MMR) search against a MongoDB Atlas
// collection, first directly on the vector store and then through a retriever.

// The connection string is read from the MONGODB_ATLAS_URI environment variable
// (falls back to an empty string when unset).
const client = new MongoClient(process.env.MONGODB_ATLAS_URI || "");
// The namespace has the form "<database>.<collection>"; split it to look up the collection.
const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

// Wrap the collection in a vector store that uses Cohere embeddings.
const vectorStore = new MongoDBAtlasVectorSearch(new CohereEmbeddings(), {
collection,
indexName: "default", // The name of the Atlas search index. Defaults to "default"
textKey: "text", // The name of the collection field containing the raw content. Defaults to "text"
embeddingKey: "embedding", // The name of the collection field containing the embedded text. Defaults to "embedding"
});

// Run an MMR search directly on the vector store and print the results.
const resultOne = await vectorStore.maxMarginalRelevanceSearch("Hello world", {
k: 4,
fetchK: 20, // The number of documents to return on initial fetch
});
console.log(resultOne);

// Using MMR in a vector store retriever

// NOTE(review): asRetriever appears to return the retriever synchronously, so the
// `await` below may be unnecessary — confirm against the VectorStore base class.
const retriever = await vectorStore.asRetriever({
searchType: "mmr",
// Extra MMR options forwarded to the vector store's maxMarginalRelevanceSearch.
searchKwargs: {
fetchK: 20,
// NOTE(review): presumably balances relevance against diversity — confirm which
// direction 0.1 favours against the maximalMarginalRelevance helper.
lambda: 0.1,
},
});

const retrieverOutput = await retriever.getRelevantDocuments("Hello world");

console.log(retrieverOutput);

// Close the MongoDB connection once both searches have finished.
await client.close();
16 changes: 8 additions & 8 deletions langchain/src/retrievers/hyde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ export type PromptKey =
| "trec-news"
| "mr-tydi";

export interface HydeRetrieverOptions<V extends VectorStore>
extends VectorStoreRetrieverInput<V> {
llm: BaseLanguageModel;
promptTemplate?: BasePromptTemplate | PromptKey;
}
export type HydeRetrieverOptions<V extends VectorStore> =
VectorStoreRetrieverInput<V> & {
llm: BaseLanguageModel;
promptTemplate?: BasePromptTemplate | PromptKey;
};

export class HydeRetriever<
V extends VectorStore = VectorStore
Expand Down Expand Up @@ -86,17 +86,17 @@ export function getPromptTemplateFromKey(key: PromptKey): BasePromptTemplate {

switch (key) {
case "websearch":
template = `Please write a passage to answer the question
template = `Please write a passage to answer the question
Question: {question}
Passage:`;
break;
case "scifact":
template = `Please write a scientific paper passage to support/refute the claim
template = `Please write a scientific paper passage to support/refute the claim
Claim: {question}
Passage:`;
break;
case "arguana":
template = `Please write a counter argument for the passage
template = `Please write a counter argument for the passage
Passage: {question}
Counter Argument:`;
break;
Expand Down
68 changes: 58 additions & 10 deletions langchain/src/vectorstores/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,33 @@ type AddDocumentOptions = Record<string, any>;

export type MaxMarginalRelevanceSearchOptions<FilterType> = {
k: number;
fetchK: number;
lambda: number;
fetchK?: number;
lambda?: number;
filter?: FilterType;
};

export interface VectorStoreRetrieverInput<V extends VectorStore>
extends BaseRetrieverInput {
vectorStore: V;
k?: number;
filter?: V["FilterType"];
}
export type VectorStoreRetrieverMMRSearchKwargs = {
fetchK?: number;
lambda?: number;
};

export type VectorStoreRetrieverInput<V extends VectorStore> =
BaseRetrieverInput &
(
| {
vectorStore: V;
k?: number;
filter?: V["FilterType"];
searchType?: "similarity";
}
| {
vectorStore: V;
k?: number;
filter?: V["FilterType"];
searchType: "mmr";
searchKwargs?: VectorStoreRetrieverMMRSearchKwargs;
}
);

export class VectorStoreRetriever<
V extends VectorStore = VectorStore
Expand All @@ -35,6 +51,10 @@ export class VectorStoreRetriever<

k = 4;

searchType = "similarity";

searchKwargs?: VectorStoreRetrieverMMRSearchKwargs;

filter?: V["FilterType"];

_vectorstoreType(): string {
Expand All @@ -45,13 +65,33 @@ export class VectorStoreRetriever<
super(fields);
this.vectorStore = fields.vectorStore;
this.k = fields.k ?? this.k;
this.searchType = fields.searchType ?? this.searchType;
this.filter = fields.filter;
if (fields.searchType === "mmr") {
this.searchKwargs = fields.searchKwargs;
}
}

async _getRelevantDocuments(
query: string,
runManager?: CallbackManagerForRetrieverRun
): Promise<Document[]> {
if (this.searchType === "mmr") {
if (typeof this.vectorStore.maxMarginalRelevanceSearch !== "function") {
throw new Error(
`The vector store backing this retriever, ${this._vectorstoreType()} does not support max marginal relevance search.`
);
}
return this.vectorStore.maxMarginalRelevanceSearch(
query,
{
k: this.k,
filter: this.filter,
...this.searchKwargs,
},
runManager?.getChild("vectorstore")
);
}
return this.vectorStore.similaritySearch(
query,
this.k,
Expand Down Expand Up @@ -196,15 +236,23 @@ export abstract class VectorStore extends Serializable {
callbacks,
});
} else {
return new VectorStoreRetriever({
const params = {
vectorStore: this,
k: kOrFields?.k,
filter: kOrFields?.filter,
tags: [...(kOrFields?.tags ?? []), this._vectorstoreType()],
metadata: kOrFields?.metadata,
verbose: kOrFields?.verbose,
callbacks: kOrFields?.callbacks,
});
searchType: kOrFields?.searchType,
};
if (kOrFields?.searchType === "mmr") {
return new VectorStoreRetriever({
...params,
searchKwargs: kOrFields.searchKwargs,
});
}
return new VectorStoreRetriever({ ...params });
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions langchain/src/vectorstores/mongodb_atlas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,22 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
query: string,
options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>
): Promise<Document[]> {
const { k, fetchK = 20, lambda = 0.5, filter } = options;

const queryEmbedding = await this.embeddings.embedQuery(query);

// preserve the original value of includeEmbeddings
const includeEmbeddingsFlag = options.filter?.includeEmbeddings || false;

// update filter to include embeddings, as they will be used in MMR
const includeEmbeddingsFilter = {
...options.filter,
...filter,
includeEmbeddings: true,
};

const resultDocs = await this.similaritySearchVectorWithScore(
queryEmbedding,
options.fetchK,
fetchK,
includeEmbeddingsFilter
);

Expand All @@ -164,8 +166,8 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
const mmrIndexes = maximalMarginalRelevance(
queryEmbedding,
embeddingList,
options.lambda,
options.k
lambda,
k
);

return mmrIndexes.map((idx) => {
Expand Down
Loading

1 comment on commit ac0e6c6

@vercel
Copy link

@vercel vercel bot commented on ac0e6c6 Aug 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.