Skip to content

Commit

Permalink
Fix Issue 2108: Atlas vector store retriever breaking on getRelevantD…
Browse files Browse the repository at this point in the history
…ocuments (#2120)

* Fix: Atlas vector store retriever breaking on getRelevantDocuments - Issue #2108

* Adds partial shim, update interface

* Changed input verification on new shim

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
sdellama and jacoblee93 authored Jul 30, 2023
1 parent c792069 commit c14cbcc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
2 changes: 1 addition & 1 deletion langchain/src/vectorstores/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ export abstract class VectorStore extends Serializable {
}

asRetriever(
kOrFields?: number | VectorStoreRetrieverInput<this>,
kOrFields?: number | Partial<VectorStoreRetrieverInput<this>>,
filter?: this["FilterType"],
callbacks?: Callbacks,
tags?: string[],
Expand Down
55 changes: 26 additions & 29 deletions langchain/src/vectorstores/mongodb_atlas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,27 @@ import { Embeddings } from "../embeddings/base.js";
import { Document } from "../document.js";

export type MongoDBAtlasVectorSearchLibArgs = {
collection: Collection<MongoDBDocument>;
indexName?: string;
textKey?: string;
embeddingKey?: string;
readonly collection: Collection<MongoDBDocument>;
readonly indexName?: string;
readonly textKey?: string;
readonly embeddingKey?: string;
};

type MongoDBAtlasFilter = {
preFilter?: MongoDBDocument;
postFilterPipeline?: MongoDBDocument[];
} & MongoDBDocument;

export class MongoDBAtlasVectorSearch extends VectorStore {
declare FilterType: MongoDBDocument;
declare FilterType: MongoDBAtlasFilter;

collection: Collection<MongoDBDocument>;
private readonly collection: Collection<MongoDBDocument>;

indexName: string;
private readonly indexName: string;

textKey: string;
private readonly textKey: string;

embeddingKey: string;
private readonly embeddingKey: string;

_vectorstoreType(): string {
return "mongodb_atlas";
Expand All @@ -28,9 +33,9 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
constructor(embeddings: Embeddings, args: MongoDBAtlasVectorSearchLibArgs) {
super(embeddings, args);
this.collection = args.collection;
this.indexName = args.indexName || "default";
this.textKey = args.textKey || "text";
this.embeddingKey = args.embeddingKey || "embedding";
this.indexName = args.indexName ?? "default";
this.textKey = args.textKey ?? "text";
this.embeddingKey = args.embeddingKey ?? "embedding";
}

async addVectors(vectors: number[][], documents: Document[]): Promise<void> {
Expand All @@ -53,14 +58,21 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
async similaritySearchVectorWithScore(
query: number[],
k: number,
preFilter?: MongoDBDocument,
postFilterPipeline?: MongoDBDocument[]
filter?: MongoDBAtlasFilter
): Promise<[Document, number][]> {
const knnBeta: MongoDBDocument = {
vector: query,
path: this.embeddingKey,
k,
};

let preFilter: MongoDBDocument | undefined;
let postFilterPipeline: MongoDBDocument[] | undefined;
if (filter?.preFilter || filter?.postFilterPipeline) {
preFilter = filter.preFilter;
postFilterPipeline = filter.postFilterPipeline;
} else preFilter = filter;

if (preFilter) {
knnBeta.filter = preFilter;
}
Expand Down Expand Up @@ -94,21 +106,6 @@ export class MongoDBAtlasVectorSearch extends VectorStore {
return ret;
}

async similaritySearch(
query: string,
k: number,
preFilter?: MongoDBDocument,
postFilterPipeline?: MongoDBDocument[]
): Promise<Document[]> {
const results = await this.similaritySearchVectorWithScore(
await this.embeddings.embedQuery(query),
k,
preFilter,
postFilterPipeline
);
return results.map((result) => result[0]);
}

static async fromTexts(
texts: string[],
metadatas: object[] | object,
Expand Down
9 changes: 9 additions & 0 deletions langchain/src/vectorstores/tests/mongodb_atlas.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ test.skip("MongoDBAtlasVectorSearch with external ids", async () => {
);

expect(filteredResults).toEqual([]);

const retriever = vectorStore.asRetriever({
filter: {
preFilter,
},
});

const docs = await retriever.getRelevantDocuments("That fence is purple");
console.log(docs);
} finally {
await client.close();
}
Expand Down

1 comment on commit c14cbcc

@vercel
Copy link

@vercel vercel bot commented on c14cbcc Jul 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.