Skip to content

Commit

Permalink
Adding Maximal Marginal Relevance method to MongoDBAtlasVectorSearch (#…
Browse files Browse the repository at this point in the history
…2071)

* Refactored MongoDBAtlasVectorSearch to use abstract similaritySearchVectorWithScore for MMR

Co-authored-by: Simone DM <simone.dellamarca@mongodb.com>

* Fixed formatting

* Fixed formatting

* Fixed formatting

* Added optional maxMarginalRelevanceSearch method to base vectorstore

* Change max marginal relevance args

* Docs

---------

Co-authored-by: Simone DM <simone.dellamarca@mongodb.com>
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
3 people authored Aug 5, 2023
1 parent 66ec957 commit 96c9fe5
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 8 deletions.
28 changes: 28 additions & 0 deletions langchain/src/vectorstores/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ import {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type AddDocumentOptions = Record<string, any>;

/**
* Options for a maximal marginal relevance (MMR) search.
*
* `FilterType` is the filter shape used by the concrete vector store.
*/
export type MaxMarginalRelevanceSearchOptions<FilterType> = {
/** Number of documents to return. */
k: number;
/** Number of documents to fetch before passing to the MMR algorithm. */
fetchK: number;
/**
* Number between 0 and 1 that determines the degree of diversity among the
* results, where 0 corresponds to maximum diversity and 1 to minimum diversity.
*/
lambda: number;
/** Optional filter. */
filter?: FilterType;
};

export interface VectorStoreRetrieverInput<V extends VectorStore>
extends BaseRetrieverInput {
vectorStore: V;
Expand Down Expand Up @@ -126,6 +133,27 @@ export abstract class VectorStore extends Serializable {
);
}

/**
* Return documents selected using the maximal marginal relevance.
* Maximal marginal relevance optimizes for similarity to the query AND diversity
* among selected documents.
*
* This member is declared optional (`?`) and has no body here, so a given
* vector store may not provide it.
*
* @param {string} query - Text to look up documents similar to.
* @param {MaxMarginalRelevanceSearchOptions} options - Search options; individual fields are described below.
* @param {number} options.k - Number of documents to return.
* @param {number} options.fetchK - Number of documents to fetch before passing to the MMR algorithm.
* @param {number} options.lambda - Number between 0 and 1 that determines the degree of diversity among the results,
* where 0 corresponds to maximum diversity and 1 to minimum diversity.
* @param {this["FilterType"]} options.filter - Optional filter
* @param {Callbacks | undefined} _callbacks - Callbacks argument. It is a required positional parameter in this
* signature (pass `undefined` when there are none); per the inline note on the parameter, forwarding it to
* embedQuery is not implemented yet.
*
* @returns {Promise<Document[]>} - List of documents selected by maximal marginal relevance.
*/
async maxMarginalRelevanceSearch?(
query: string,
options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>,
_callbacks: Callbacks | undefined // implement passing to embedQuery later
): Promise<Document[]>;

static fromTexts(
_texts: string[],
_metadatas: object[] | object,
Expand Down
87 changes: 80 additions & 7 deletions langchain/src/vectorstores/mongodb_atlas.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import type { Collection, Document as MongoDBDocument } from "mongodb";
import { VectorStore } from "./base.js";
import { MaxMarginalRelevanceSearchOptions, VectorStore } from "./base.js";
import { Embeddings } from "../embeddings/base.js";
import { Document } from "../document.js";
import { maximalMarginalRelevance } from "../util/math.js";

export type MongoDBAtlasVectorSearchLibArgs = {
readonly collection: Collection<MongoDBDocument>;
Expand All @@ -13,6 +14,7 @@ export type MongoDBAtlasVectorSearchLibArgs = {
/**
* Filter accepted by the MongoDB Atlas vector store searches.
*
* It is either a structured object using the optional keys below or, when none
* of those keys is set, a plain document that the search uses as the
* pre-filter as a whole.
*/
type MongoDBAtlasFilter = {
/** Optional Atlas Search operator used to pre-filter on document fields. */
preFilter?: MongoDBDocument;
/** Optional aggregation stages appended to the pipeline after the search and score stages. */
postFilterPipeline?: MongoDBDocument[];
/**
* When truthy, the stage that strips the embedding field from the results is
* not added, so the embedding remains in each returned document's metadata.
*/
includeEmbeddings?: boolean;
} & MongoDBDocument;

export class MongoDBAtlasVectorSearch extends VectorStore {
Expand Down Expand Up @@ -68,9 +70,15 @@ export class MongoDBAtlasVectorSearch extends VectorStore {

let preFilter: MongoDBDocument | undefined;
let postFilterPipeline: MongoDBDocument[] | undefined;
if (filter?.preFilter || filter?.postFilterPipeline) {
let includeEmbeddings: boolean | undefined;
if (
filter?.preFilter ||
filter?.postFilterPipeline ||
filter?.includeEmbeddings
) {
preFilter = filter.preFilter;
postFilterPipeline = filter.postFilterPipeline;
includeEmbeddings = filter.includeEmbeddings || false;
} else preFilter = filter;

if (preFilter) {
Expand All @@ -84,28 +92,93 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
},
},
{
$project: {
[this.embeddingKey]: 0,
$set: {
score: { $meta: "searchScore" },
},
},
];

if (!includeEmbeddings) {
const removeEmbeddingsStage = {
$project: {
[this.embeddingKey]: 0,
},
};
pipeline.push(removeEmbeddingsStage);
}

if (postFilterPipeline) {
pipeline.push(...postFilterPipeline);
}
const results = this.collection.aggregate(pipeline);

const ret: [Document, number][] = [];
for await (const result of results) {
const text = result[this.textKey];
delete result[this.textKey];
const { score, ...metadata } = result;
const { score, [this.textKey]: text, ...metadata } = result;
ret.push([new Document({ pageContent: text, metadata }), score]);
}

return ret;
}

/**
 * Return documents selected using the maximal marginal relevance.
 * Maximal marginal relevance optimizes for similarity to the query AND diversity
 * among selected documents.
 *
 * The search first fetches `fetchK` candidates (with their embeddings) through
 * `similaritySearchVectorWithScore`, then re-ranks them with the MMR algorithm
 * and returns at most `k` of them.
 *
 * @param query - Text to look up documents similar to.
 * @param options.k - Number of documents to return.
 * @param options.fetchK - Number of documents to fetch before passing to the MMR algorithm.
 *   Falls back to 20 when null or undefined.
 * @param options.lambda - Number between 0 and 1 that determines the degree of diversity among the results,
 *   where 0 corresponds to maximum diversity and 1 to minimum diversity.
 *   Falls back to 0.5 when null or undefined.
 * @param options.filter - Optional Atlas Search operator to pre-filter on document fields
 *   or post-filter following the knnBeta search. May be a structured filter
 *   (`preFilter` / `postFilterPipeline` / `includeEmbeddings`) or a plain pre-filter document.
 *
 * @returns List of documents selected by maximal marginal relevance. The embedding
 *   field is removed from each document's metadata unless `filter.includeEmbeddings` was set.
 */
async maxMarginalRelevanceSearch(
  query: string,
  options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>
): Promise<Document[]> {
  const { k, filter } = options;

  // The documented defaults were never applied; honor them for callers
  // (e.g. untyped JavaScript) that leave these fields out.
  const fetchK = options.fetchK ?? 20;
  const lambda = options.lambda ?? 0.5;

  const queryEmbedding = await this.embeddings.embedQuery(query);

  // Preserve the caller's original choice so embeddings can be stripped again below.
  const includeEmbeddingsFlag = Boolean(filter?.includeEmbeddings);

  // similaritySearchVectorWithScore treats a filter as "structured" as soon as
  // any of these keys is truthy, and then only reads `filter.preFilter`.
  const isStructuredFilter = Boolean(
    filter?.preFilter || filter?.postFilterPipeline || filter?.includeEmbeddings
  );
  const hasPlainPreFilter =
    !isStructuredFilter && filter != null && Object.keys(filter).length > 0;

  // Embeddings are always requested because MMR needs them. A plain pre-filter
  // document must be wrapped under `preFilter`: merely spreading
  // `includeEmbeddings: true` into it would flip it to the structured branch
  // and the caller's filter would be silently ignored.
  const includeEmbeddingsFilter = hasPlainPreFilter
    ? { preFilter: filter, includeEmbeddings: true }
    : { ...filter, includeEmbeddings: true };

  const resultDocs = await this.similaritySearchVectorWithScore(
    queryEmbedding,
    fetchK,
    includeEmbeddingsFilter
  );

  // Nothing to re-rank.
  if (resultDocs.length === 0) {
    return [];
  }

  const embeddingList = resultDocs.map(
    ([doc]) => doc.metadata[this.embeddingKey]
  );

  const mmrIndexes = maximalMarginalRelevance(
    queryEmbedding,
    embeddingList,
    lambda,
    k
  );

  return mmrIndexes.map((idx) => {
    const doc = resultDocs[idx][0];

    // Remove embeddings if they were not requested originally.
    if (!includeEmbeddingsFlag) {
      delete doc.metadata[this.embeddingKey];
    }
    return doc;
  });
}

static async fromTexts(
texts: string[],
metadatas: object[] | object,
Expand Down
44 changes: 43 additions & 1 deletion langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,49 @@ test.skip("MongoDBAtlasVectorSearch with external ids", async () => {
});

const docs = await retriever.getRelevantDocuments("That fence is purple");
console.log(docs);
expect(docs).toEqual([]);
} finally {
await client.close();
}
});

test.skip("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () => {
expect(process.env.MONGODB_ATLAS_URI).toBeDefined();
expect(
process.env.OPENAI_API_KEY || process.env.AZURE_OPENAI_API_KEY
).toBeDefined();

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const client = new MongoClient(process.env.MONGODB_ATLAS_URI!);
try {
const namespace = "langchain.test";
const [dbName, collectionName] = namespace.split(".");
const collection = client.db(dbName).collection(collectionName);

await collection.deleteMany({});

const texts = ["foo", "foo", "foy"];
const vectorStore = await MongoDBAtlasVectorSearch.fromTexts(
texts,
{},
new CohereEmbeddings(),
{ collection }
);

// Sleep for 2 seconds so the Atlas index has time to replicate the new documents.
await sleep(2000);

const output = await vectorStore.maxMarginalRelevanceSearch("foo", {
k: 10,
fetchK: 20,
lambda: 0.1,
});

expect(output).toHaveLength(texts.length);

const actual = output.map((doc) => doc.pageContent);
const expected = ["foo", "foy", "foo"];
expect(actual).toEqual(expected);
} finally {
await client.close();
}
Expand Down

1 comment on commit 96c9fe5

@vercel
Copy link

@vercel vercel bot commented on 96c9fe5 Aug 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.