From c14cbcc696bf9ca3d27513e5b762cf56aa1c00ab Mon Sep 17 00:00:00 2001 From: Simone Della Marca <97287757+sdellama@users.noreply.github.com> Date: Sun, 30 Jul 2023 19:00:06 +0200 Subject: [PATCH] Fix Issue 2108: Atlas vector store retriever breaking on getRelevantDocuments (#2120) * Fix: Atlas vector store retriever breaking on getRelevantDocuments - Issue #2108 * Adds partial shim, update interface * Changed input verification on new shim --------- Co-authored-by: jacoblee93 --- langchain/src/vectorstores/base.ts | 2 +- langchain/src/vectorstores/mongodb_atlas.ts | 55 +++++++++---------- .../tests/mongodb_atlas.int.test.ts | 9 +++ 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/langchain/src/vectorstores/base.ts b/langchain/src/vectorstores/base.ts index 5f096da7ddbc..b13a368b187b 100644 --- a/langchain/src/vectorstores/base.ts +++ b/langchain/src/vectorstores/base.ts @@ -150,7 +150,7 @@ export abstract class VectorStore extends Serializable { } asRetriever( - kOrFields?: number | VectorStoreRetrieverInput, + kOrFields?: number | Partial>, filter?: this["FilterType"], callbacks?: Callbacks, tags?: string[], diff --git a/langchain/src/vectorstores/mongodb_atlas.ts b/langchain/src/vectorstores/mongodb_atlas.ts index a1765b383826..31f9ac160c6e 100755 --- a/langchain/src/vectorstores/mongodb_atlas.ts +++ b/langchain/src/vectorstores/mongodb_atlas.ts @@ -4,22 +4,27 @@ import { Embeddings } from "../embeddings/base.js"; import { Document } from "../document.js"; export type MongoDBAtlasVectorSearchLibArgs = { - collection: Collection; - indexName?: string; - textKey?: string; - embeddingKey?: string; + readonly collection: Collection; + readonly indexName?: string; + readonly textKey?: string; + readonly embeddingKey?: string; }; +type MongoDBAtlasFilter = { + preFilter?: MongoDBDocument; + postFilterPipeline?: MongoDBDocument[]; +} & MongoDBDocument; + export class MongoDBAtlasVectorSearch extends VectorStore { - declare FilterType: MongoDBDocument; + declare FilterType: MongoDBAtlasFilter; - collection: Collection; + private readonly collection: Collection; - indexName: string; + private readonly indexName: string; - textKey: string; + private readonly textKey: string; - embeddingKey: string; + private readonly embeddingKey: string; _vectorstoreType(): string { return "mongodb_atlas"; @@ -28,9 +33,9 @@ export class MongoDBAtlasVectorSearch extends VectorStore { constructor(embeddings: Embeddings, args: MongoDBAtlasVectorSearchLibArgs) { super(embeddings, args); this.collection = args.collection; - this.indexName = args.indexName || "default"; - this.textKey = args.textKey || "text"; - this.embeddingKey = args.embeddingKey || "embedding"; + this.indexName = args.indexName ?? "default"; + this.textKey = args.textKey ?? "text"; + this.embeddingKey = args.embeddingKey ?? "embedding"; } async addVectors(vectors: number[][], documents: Document[]): Promise { @@ -53,14 +58,21 @@ export class MongoDBAtlasVectorSearch extends VectorStore { async similaritySearchVectorWithScore( query: number[], k: number, - preFilter?: MongoDBDocument, - postFilterPipeline?: MongoDBDocument[] + filter?: MongoDBAtlasFilter ): Promise<[Document, number][]> { const knnBeta: MongoDBDocument = { vector: query, path: this.embeddingKey, k, }; + + let preFilter: MongoDBDocument | undefined; + let postFilterPipeline: MongoDBDocument[] | undefined; + if (filter?.preFilter || filter?.postFilterPipeline) { + preFilter = filter.preFilter; + postFilterPipeline = filter.postFilterPipeline; + } else preFilter = filter; + if (preFilter) { knnBeta.filter = preFilter; } @@ -94,21 +106,6 @@ export class MongoDBAtlasVectorSearch extends VectorStore { return ret; } - async similaritySearch( - query: string, - k: number, - preFilter?: MongoDBDocument, - postFilterPipeline?: MongoDBDocument[] - ): Promise { - const results = await this.similaritySearchVectorWithScore( - await this.embeddings.embedQuery(query), - k, - preFilter, - postFilterPipeline - ); - return results.map((result) => result[0]); - } - static async fromTexts( texts: string[], metadatas: object[] | object, diff --git a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts b/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts index 48fa6ee5b88a..f715616c84e3 100755 --- a/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts +++ b/langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts @@ -80,6 +80,15 @@ test.skip("MongoDBAtlasVectorSearch with external ids", async () => { ); expect(filteredResults).toEqual([]); + + const retriever = vectorStore.asRetriever({ + filter: { + preFilter, + }, + }); + + const docs = await retriever.getRelevantDocuments("That fence is purple"); + console.log(docs); } finally { await client.close(); }