Recursive k similarity search (#2410)
* feat(vectorstores/base.ts): add support for recursive similarity searches

* chore(docs): update language and formatting in the Recursive Similarity Search section

* fix: fix docs

* feat: move to SimilarityScoreThresholdVectorStoreRetriever

* docs(vector_stores): remove unused import CodeBlock
docs(vector_stores): remove similarityFilter parameter from asRetriever method

* refactor(conversational_retrieval_chain.int.test.ts): remove eslint-disable comment

* refactor(vectorstores/base.ts): remove unused import and similarityFilter parameter in VectorStoreRetrieverInput
refactor(vectorstores/base.ts): simplify object creation in VectorStoreRetriever method

* refactor(similarity-score-threshold.ts): remove debug console.log statement

* fix: use the right name convention

* fix: use the right name convention

* fix(similarity_score_threshold.ts): fix infinite loop condition in SimilarityScoreThresholdVectorStoreRetriever

* Refactor

* fix(vectorstores/index.ts): rename SimilarityScoreThresholdVectorStoreRetriever to ScoreThresholdVectorStoreRetriever

* Rename, fix entrypoints

* Add similarity score threshold retriever example

* Fix name

* Fix formatting

* Update docs

---------

Co-authored-by: João Melo <jopcmelo@gmail.com>
jacoblee93 and joaopcm authored Aug 26, 2023
1 parent 23c20da commit c6899ae
Showing 16 changed files with 268 additions and 4 deletions.
@@ -0,0 +1,16 @@
# Similarity Score Threshold

A common problem with similarity search is that you have to supply a `k` value, which controls how many similar results come back. But what if you don't know the right `k` in advance? What if you want the system to return every relevant result?

In a real-world scenario, imagine a very long document written by a product manager that describes a product with 10, 15, 20, 100, or more features. How do you choose a `k` value so that the question "What are all the features that product X has?" returns every relevant passage?

To solve this problem, LangChain offers a recursive similarity search via the `ScoreThresholdRetriever`. Instead of relying solely on a fixed `k`, you specify a minimum similarity score, and the retriever keeps widening the search until it has collected every result that meets that threshold.

You can use this recursive similarity search by wrapping an existing vector store as a retriever, as sketched below.
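A minimal sketch (the store contents and the 0.9 threshold are illustrative; the complete runnable example is embedded under Usage below):

```typescript
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { ScoreThresholdRetriever } from "langchain/retrievers/score_threshold";

// A small in-memory store to search over.
const vectorStore = await MemoryVectorStore.fromTexts(
  ["Buildings are made out of brick", "Cars are made out of metal"],
  [{ id: 1 }, { id: 2 }],
  new OpenAIEmbeddings()
);

const retriever = ScoreThresholdRetriever.fromVectorStore(vectorStore, {
  minSimilarityScore: 0.9, // only return results scoring at least this much
  maxK: 100, // hard cap so the search always terminates
  kIncrement: 2, // how much to widen the search on each pass
});

// Returns every document above the threshold rather than a fixed number of results.
const documents = await retriever.getRelevantDocuments(
  "What are buildings made out of?"
);
```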

## Usage

import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/retrievers/similarity_score_threshold.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>
14 changes: 10 additions & 4 deletions examples/src/guides/expression_language/cookbook_retriever_map.ts
@@ -33,22 +33,28 @@ const serializeDocs = (docs: Document[]) =>
 
 const languageChain = RunnableSequence.from([
   {
+    // Every property in the map receives the same input,
+    // so we need to extract just the standalone question to pass into the retriever.
+    // We then serialize the retrieved docs into a string to pass into the prompt.
+    context: RunnableSequence.from([
+      (input: LanguageChainInput) => input.question,
+      retriever,
+      serializeDocs,
+    ]),
     question: (input: LanguageChainInput) => input.question,
     language: (input: LanguageChainInput) => input.language,
-    context: (input: LanguageChainInput) =>
-      retriever.pipe(serializeDocs).invoke(input.question),
   },
   languagePrompt,
   model,
   new StringOutputParser(),
 ]);
 
-const result2 = await languageChain.invoke({
+const result = await languageChain.invoke({
   question: "What is the powerhouse of the cell?",
   language: "German",
 });
 
-console.log(result2);
+console.log(result);
 
 /*
   "Mitochondrien sind das Kraftwerk der Zelle."
54 changes: 54 additions & 0 deletions examples/src/retrievers/similarity_score_threshold.ts
@@ -0,0 +1,54 @@
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { ScoreThresholdRetriever } from "langchain/retrievers/score_threshold";

const vectorStore = await MemoryVectorStore.fromTexts(
[
"Buildings are made out of brick",
"Buildings are made out of wood",
"Buildings are made out of stone",
"Buildings are made out of atoms",
"Buildings are made out of building materials",
"Cars are made out of metal",
"Cars are made out of plastic",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }, { id: 6 }, { id: 7 }],
new OpenAIEmbeddings()
);

const retriever = ScoreThresholdRetriever.fromVectorStore(vectorStore, {
minSimilarityScore: 0.9, // Finds results with at least this similarity score
maxK: 100, // The maximum K value to use. Set it based on your chunk size to make sure you don't run out of tokens
kIncrement: 2, // How much to increase K by each time. It'll fetch N results, then N + kIncrement, then N + kIncrement * 2, etc.
});
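// With kIncrement: 2, the retriever asks the store for k = 2 results, then 4, then 6, and so on.
// It stops as soon as a fetched batch contains a result below minSimilarityScore
// (or fewer results than requested), or once k reaches maxK.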

const result = await retriever.getRelevantDocuments(
"What are buildings made out of?"
);

console.log(result);

/*
[
Document {
pageContent: 'Buildings are made out of building materials',
metadata: { id: 5 }
},
Document {
pageContent: 'Buildings are made out of wood',
metadata: { id: 2 }
},
Document {
pageContent: 'Buildings are made out of brick',
metadata: { id: 1 }
},
Document {
pageContent: 'Buildings are made out of stone',
metadata: { id: 3 }
},
Document {
pageContent: 'Buildings are made out of atoms',
metadata: { id: 4 }
}
]
*/
3 changes: 3 additions & 0 deletions langchain/.gitignore
@@ -448,6 +448,9 @@ retrievers/document_compressors/chain_extract.d.ts
retrievers/hyde.cjs
retrievers/hyde.js
retrievers/hyde.d.ts
retrievers/score_threshold.cjs
retrievers/score_threshold.js
retrievers/score_threshold.d.ts
retrievers/self_query.cjs
retrievers/self_query.js
retrievers/self_query.d.ts
8 changes: 8 additions & 0 deletions langchain/package.json
@@ -460,6 +460,9 @@
"retrievers/hyde.cjs",
"retrievers/hyde.js",
"retrievers/hyde.d.ts",
"retrievers/score_threshold.cjs",
"retrievers/score_threshold.js",
"retrievers/score_threshold.d.ts",
"retrievers/self_query.cjs",
"retrievers/self_query.js",
"retrievers/self_query.d.ts",
@@ -1827,6 +1830,11 @@
"import": "./retrievers/hyde.js",
"require": "./retrievers/hyde.cjs"
},
"./retrievers/score_threshold": {
"types": "./retrievers/score_threshold.d.ts",
"import": "./retrievers/score_threshold.js",
"require": "./retrievers/score_threshold.cjs"
},
"./retrievers/self_query": {
"types": "./retrievers/self_query.d.ts",
"import": "./retrievers/self_query.js",
1 change: 1 addition & 0 deletions langchain/scripts/create-entrypoints.js
@@ -184,6 +184,7 @@ const entrypoints = {
"retrievers/document_compressors/chain_extract":
"retrievers/document_compressors/chain_extract",
"retrievers/hyde": "retrievers/hyde",
"retrievers/score_threshold": "retrievers/score_threshold",
"retrievers/self_query": "retrievers/self_query/index",
"retrievers/self_query/chroma": "retrievers/self_query/chroma",
"retrievers/self_query/functional": "retrievers/self_query/functional",
1 change: 1 addition & 0 deletions langchain/src/load/import_map.ts
@@ -50,6 +50,7 @@ export * as retrievers__parent_document from "../retrievers/parent_document.js";
export * as retrievers__time_weighted from "../retrievers/time_weighted.js";
export * as retrievers__document_compressors__chain_extract from "../retrievers/document_compressors/chain_extract.js";
export * as retrievers__hyde from "../retrievers/hyde.js";
export * as retrievers__score_threshold from "../retrievers/score_threshold.js";
export * as retrievers__vespa from "../retrievers/vespa.js";
export * as cache from "../cache/index.js";
export * as stores__doc__in_memory from "../stores/doc/in_memory.js";
57 changes: 57 additions & 0 deletions langchain/src/retrievers/score_threshold.ts
@@ -0,0 +1,57 @@
import {
VectorStore,
VectorStoreRetriever,
VectorStoreRetrieverInput,
} from "../vectorstores/base.js";
import { Document } from "../document.js";

export type ScoreThresholdRetrieverInput<V extends VectorStore> = Omit<
VectorStoreRetrieverInput<V>,
"k"
> & {
/** Maximum number of documents to fetch and return. Defaults to 100. */
maxK?: number;
/** How much to increase k by on each pass of the search loop. Defaults to 10. */
kIncrement?: number;
/** Minimum similarity score a document must have to be returned. */
minSimilarityScore: number;
};

export class ScoreThresholdRetriever<
V extends VectorStore
> extends VectorStoreRetriever<V> {
minSimilarityScore: number;

kIncrement = 10;

maxK = 100;

constructor(input: ScoreThresholdRetrieverInput<V>) {
super(input);
this.maxK = input.maxK ?? this.maxK;
this.minSimilarityScore =
input.minSimilarityScore ?? this.minSimilarityScore;
this.kIncrement = input.kIncrement ?? this.kIncrement;
}

async getRelevantDocuments(query: string): Promise<Document[]> {
let currentK = 0;
let filteredResults: [Document, number][] = [];
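// Fetch progressively larger result sets. As long as every result in the latest
// batch clears the minimum score, a larger k might surface more qualifying
// documents, so keep going; maxK bounds the search so it always terminates.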
do {
currentK += this.kIncrement;
const results = await this.vectorStore.similaritySearchWithScore(
query,
currentK,
this.filter
);
filteredResults = results.filter(
([, score]) => score >= this.minSimilarityScore
);
} while (filteredResults.length >= currentK && currentK < this.maxK);
return filteredResults.map(([document]) => document).slice(0, this.maxK);
}

static fromVectorStore<V extends VectorStore>(
vectorStore: V,
options: Omit<ScoreThresholdRetrieverInput<V>, "vectorStore">
) {
return new this<V>({ ...options, vectorStore });
}
}
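To make the retrieval loop above concrete, here is a brief sketch of constructing the retriever directly (the store contents, the 0.9 threshold, and the `kIncrement`/`maxK` values are illustrative):

```typescript
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { ScoreThresholdRetriever } from "langchain/retrievers/score_threshold";

const vectorStore = await MemoryVectorStore.fromTexts(
  [
    "Buildings are made out of brick",
    "Buildings are made out of wood",
    "Cars are made out of metal",
  ],
  [{ id: 1 }, { id: 2 }, { id: 3 }],
  new OpenAIEmbeddings()
);

// With kIncrement: 1, getRelevantDocuments first asks the store for 1 result.
// If everything returned clears minSimilarityScore, it retries with k = 2,
// then k = 3, and so on, stopping once a batch contains a result below the
// threshold (or fewer results than requested), or once k reaches maxK.
const retriever = new ScoreThresholdRetriever({
  vectorStore,
  minSimilarityScore: 0.9,
  kIncrement: 1,
  maxK: 5,
});

const docs = await retriever.getRelevantDocuments(
  "What are buildings made out of?"
);
console.log(docs.length); // likely 2 here: only the building-related documents
```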
111 changes: 111 additions & 0 deletions langchain/src/retrievers/tests/score_threshold.int.test.ts
@@ -0,0 +1,111 @@
/* eslint-disable no-process-env */
import { expect, test } from "@jest/globals";
import { ConversationalRetrievalQAChain } from "../../chains/conversational_retrieval_chain.js";
import { OpenAIEmbeddings } from "../../embeddings/openai.js";
import { ChatOpenAI } from "../../chat_models/openai.js";
import { BufferMemory } from "../../memory/buffer_memory.js";
import { MemoryVectorStore } from "../../vectorstores/memory.js";
import { ScoreThresholdRetriever } from "../score_threshold.js";

test("ConversationalRetrievalQAChain.fromLLM should search its vector store recursively until it has gathered all results that meet the provided minimum similarity score", async () => {
const vectorStore = await MemoryVectorStore.fromTexts(
[
"Buildings are made out of brick",
"Buildings are made out of wood",
"Buildings are made out of stone",
"Cars are made out of metal",
"Cars are made out of plastic",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);

const model = new ChatOpenAI({
modelName: "gpt-3.5-turbo",
temperature: 0,
});

const chain = ConversationalRetrievalQAChain.fromLLM(
model,
ScoreThresholdRetriever.fromVectorStore(vectorStore, {
minSimilarityScore: 0.9,
kIncrement: 1,
}),
{
returnSourceDocuments: true,
memory: new BufferMemory({
memoryKey: "chat_history",
inputKey: "question",
outputKey: "text",
}),
}
);
const res = await chain.call({
question: "Buildings are made out of what?",
});

console.log("response:", res);

expect(res).toEqual(
expect.objectContaining({
text: expect.any(String),
sourceDocuments: expect.arrayContaining([
expect.objectContaining({
metadata: expect.objectContaining({
id: 1,
}),
}),
expect.objectContaining({
metadata: expect.objectContaining({
id: 2,
}),
}),
expect.objectContaining({
metadata: expect.objectContaining({
id: 3,
}),
}),
]),
})
);
});

test("ConversationalRetrievalQAChain.fromLLM should use its vector store to get up to maxK results that match the provided similarity score", async () => {
const vectorStore = await MemoryVectorStore.fromTexts(
[
"Buildings are made out of brick",
"Buildings are made out of wood",
"Buildings are made out of stone",
"Cars are made out of metal",
"Cars are made out of plastic",
],
[{ id: 1 }, { id: 2 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);

const model = new ChatOpenAI({
modelName: "gpt-3.5-turbo",
temperature: 0,
});

const chain = ConversationalRetrievalQAChain.fromLLM(
model,
ScoreThresholdRetriever.fromVectorStore(vectorStore, {
minSimilarityScore: 0.9,
maxK: 2,
}),
{
returnSourceDocuments: true,
memory: new BufferMemory({
memoryKey: "chat_history",
inputKey: "question",
outputKey: "text",
}),
}
);
const res = await chain.call({
question: "Buildings are made out of what?",
});

expect(res.sourceDocuments).toHaveLength(2);
});
1 change: 1 addition & 0 deletions langchain/tsconfig.json
@@ -176,6 +176,7 @@
"src/retrievers/time_weighted.ts",
"src/retrievers/document_compressors/chain_extract.ts",
"src/retrievers/hyde.ts",
"src/retrievers/score_threshold.ts",
"src/retrievers/self_query/index.ts",
"src/retrievers/self_query/chroma.ts",
"src/retrievers/self_query/functional.ts",
1 change: 1 addition & 0 deletions test-exports-cf/src/entrypoints.js
@@ -49,6 +49,7 @@ export * from "langchain/retrievers/parent_document";
export * from "langchain/retrievers/time_weighted";
export * from "langchain/retrievers/document_compressors/chain_extract";
export * from "langchain/retrievers/hyde";
export * from "langchain/retrievers/score_threshold";
export * from "langchain/retrievers/vespa";
export * from "langchain/cache";
export * from "langchain/stores/doc/in_memory";
1 change: 1 addition & 0 deletions test-exports-cjs/src/entrypoints.js
@@ -49,6 +49,7 @@ const retrievers_parent_document = require("langchain/retrievers/parent_document
const retrievers_time_weighted = require("langchain/retrievers/time_weighted");
const retrievers_document_compressors_chain_extract = require("langchain/retrievers/document_compressors/chain_extract");
const retrievers_hyde = require("langchain/retrievers/hyde");
const retrievers_score_threshold = require("langchain/retrievers/score_threshold");
const retrievers_vespa = require("langchain/retrievers/vespa");
const cache = require("langchain/cache");
const stores_doc_in_memory = require("langchain/stores/doc/in_memory");
1 change: 1 addition & 0 deletions test-exports-esbuild/src/entrypoints.js
@@ -49,6 +49,7 @@ import * as retrievers_parent_document from "langchain/retrievers/parent_documen
import * as retrievers_time_weighted from "langchain/retrievers/time_weighted";
import * as retrievers_document_compressors_chain_extract from "langchain/retrievers/document_compressors/chain_extract";
import * as retrievers_hyde from "langchain/retrievers/hyde";
import * as retrievers_score_threshold from "langchain/retrievers/score_threshold";
import * as retrievers_vespa from "langchain/retrievers/vespa";
import * as cache from "langchain/cache";
import * as stores_doc_in_memory from "langchain/stores/doc/in_memory";
1 change: 1 addition & 0 deletions test-exports-esm/src/entrypoints.js
@@ -49,6 +49,7 @@ import * as retrievers_parent_document from "langchain/retrievers/parent_documen
import * as retrievers_time_weighted from "langchain/retrievers/time_weighted";
import * as retrievers_document_compressors_chain_extract from "langchain/retrievers/document_compressors/chain_extract";
import * as retrievers_hyde from "langchain/retrievers/hyde";
import * as retrievers_score_threshold from "langchain/retrievers/score_threshold";
import * as retrievers_vespa from "langchain/retrievers/vespa";
import * as cache from "langchain/cache";
import * as stores_doc_in_memory from "langchain/stores/doc/in_memory";
1 change: 1 addition & 0 deletions test-exports-vercel/src/entrypoints.js
@@ -49,6 +49,7 @@ export * from "langchain/retrievers/parent_document";
export * from "langchain/retrievers/time_weighted";
export * from "langchain/retrievers/document_compressors/chain_extract";
export * from "langchain/retrievers/hyde";
export * from "langchain/retrievers/score_threshold";
export * from "langchain/retrievers/vespa";
export * from "langchain/cache";
export * from "langchain/stores/doc/in_memory";
1 change: 1 addition & 0 deletions test-exports-vite/src/entrypoints.js
@@ -49,6 +49,7 @@ export * from "langchain/retrievers/parent_document";
export * from "langchain/retrievers/time_weighted";
export * from "langchain/retrievers/document_compressors/chain_extract";
export * from "langchain/retrievers/hyde";
export * from "langchain/retrievers/score_threshold";
export * from "langchain/retrievers/vespa";
export * from "langchain/cache";
export * from "langchain/stores/doc/in_memory";
