From b780ab4613963dc3c1ea0bdc014909bf6d7977fc Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Tue, 30 May 2023 16:09:38 -0700 Subject: [PATCH] Adds support for built-in memory to ConversationalQARetrievalChain (#1463) * Adds support for built-in memory to ConversationalQARetrievalChain * Fix typo in docs * More docs updates * Fix typo in docs --- .../conversational_retrieval.mdx | 21 ++- examples/src/chains/conversational_qa.ts | 12 +- .../chains/conversational_qa_customization.ts | 40 +++++ .../conversational_qa_external_memory.ts | 34 ++++ .../chains/conversational_retrieval_chain.ts | 33 +++- ...conversational_retrieval_chain.int.test.ts | 152 ++++++++++++++++++ 6 files changed, 279 insertions(+), 13 deletions(-) create mode 100644 examples/src/chains/conversational_qa_customization.ts create mode 100644 examples/src/chains/conversational_qa_external_memory.ts diff --git a/docs/docs/modules/chains/index_related_chains/conversational_retrieval.mdx b/docs/docs/modules/chains/index_related_chains/conversational_retrieval.mdx index 1f1e1e73a821..bed4d68f5686 100644 --- a/docs/docs/modules/chains/index_related_chains/conversational_retrieval.mdx +++ b/docs/docs/modules/chains/index_related_chains/conversational_retrieval.mdx @@ -10,15 +10,15 @@ import ConvoRetrievalQAExample from "@examples/chains/conversational_qa.ts"; The `ConversationalRetrievalQA` chain builds on `RetrievalQAChain` to provide a chat history component. -It requires two inputs: a question and the chat history. It first combines the chat history and the question into a standalone question, then looks up relevant documents from the retriever, and then passes those documents and the question to a question answering chain to return a response. +It first combines the chat history (either explicitly passed in or retrieved from the provided memory) and the question into a standalone question, then looks up relevant documents from the retriever, and finally passes those documents and the question to a question answering chain to return a response. -To create one, you will need a retriever. In the below example, we will create one from a vectorstore, which can be created from embeddings. +To create one, you will need a retriever. In the below example, we will create one from a vector store, which can be created from embeddings. import Example from "@examples/chains/conversational_qa.ts"; {ConvoRetrievalQAExample} -In this code snippet, the fromLLM method of the `ConversationalRetrievalQAChain` class has the following signature: +In the above code snippet, the fromLLM method of the `ConversationalRetrievalQAChain` class has the following signature: ```typescript static fromLLM( @@ -46,4 +46,17 @@ Here's an explanation of each of the attributes of the options object: You can see [documentation about the usable fields here](/docs/api/chains/types/QAChainParams). - `returnSourceDocuments`: A boolean value that indicates whether the `ConversationalRetrievalQAChain` should return the source documents that were used to retrieve the answer. If set to true, the documents will be included in the result returned by the call() method. This can be useful if you want to allow the user to see the sources used to generate the answer. If not set, the default value will be false. -In summary, the `questionGeneratorChainOptions`, `qaChainOptions`, and `returnSourceDocuments` options allow the user to customize the behavior of the `ConversationalRetrievalQAChain` +Here's a customization example using a faster LLM to generate questions and a slower, more comprehensive LLM for the final answer: + +import ConvoQACustomizationExample from "@examples/chains/conversational_qa_customization.ts"; + +{ConvoQACustomizationExample} + +## External Memory + +If you'd like to format the chat history in a specific way, you can also pass the chat history in explicitly by omitting the `memory` option and passing in +a `chat_history` string directly into the `chain.call` method: + +import ConvoQAExternalMemoryExample from "@examples/chains/conversational_qa_external_memory.ts"; + +{ConvoQAExternalMemoryExample} diff --git a/examples/src/chains/conversational_qa.ts b/examples/src/chains/conversational_qa.ts index 1a86dd399ebe..900fc510795c 100644 --- a/examples/src/chains/conversational_qa.ts +++ b/examples/src/chains/conversational_qa.ts @@ -3,6 +3,7 @@ import { ConversationalRetrievalQAChain } from "langchain/chains"; import { HNSWLib } from "langchain/vectorstores/hnswlib"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; +import { BufferMemory } from "langchain/memory"; import * as fs from "fs"; export const run = async () => { @@ -18,17 +19,20 @@ export const run = async () => { /* Create the chain */ const chain = ConversationalRetrievalQAChain.fromLLM( model, - vectorStore.asRetriever() + vectorStore.asRetriever(), + { + memory: new BufferMemory({ + memoryKey: "chat_history", // Must be set to "chat_history" + }), + } ); /* Ask it a question */ const question = "What did the president say about Justice Breyer?"; - const res = await chain.call({ question, chat_history: [] }); + const res = await chain.call({ question }); console.log(res); /* Ask it a follow up question */ - const chatHistory = question + res.text; const followUpRes = await chain.call({ question: "Was that nice?", - chat_history: chatHistory, }); console.log(followUpRes); }; diff --git a/examples/src/chains/conversational_qa_customization.ts b/examples/src/chains/conversational_qa_customization.ts new file mode 100644 index 000000000000..451b49b2c2d2 --- /dev/null +++ b/examples/src/chains/conversational_qa_customization.ts @@ -0,0 +1,40 @@ +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { ConversationalRetrievalQAChain } from "langchain/chains"; +import { HNSWLib } from "langchain/vectorstores/hnswlib"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; +import { BufferMemory } from "langchain/memory"; +import * as fs from "fs"; + +export const run = async () => { + const text = fs.readFileSync("state_of_the_union.txt", "utf8"); + const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 }); + const docs = await textSplitter.createDocuments([text]); + const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings()); + const fasterModel = new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + }); + const slowerModel = new ChatOpenAI({ + modelName: "gpt-4", + }); + const chain = ConversationalRetrievalQAChain.fromLLM( + slowerModel, + vectorStore.asRetriever(), + { + memory: new BufferMemory({ + memoryKey: "chat_history", + returnMessages: true, + }), + questionGeneratorChainOptions: { + llm: fasterModel, + }, + } + ); + /* Ask it a question */ + const question = "What did the president say about Justice Breyer?"; + const res = await chain.call({ question }); + console.log(res); + + const followUpRes = await chain.call({ question: "Was that nice?" }); + console.log(followUpRes); +}; diff --git a/examples/src/chains/conversational_qa_external_memory.ts b/examples/src/chains/conversational_qa_external_memory.ts new file mode 100644 index 000000000000..1a86dd399ebe --- /dev/null +++ b/examples/src/chains/conversational_qa_external_memory.ts @@ -0,0 +1,34 @@ +import { OpenAI } from "langchain/llms/openai"; +import { ConversationalRetrievalQAChain } from "langchain/chains"; +import { HNSWLib } from "langchain/vectorstores/hnswlib"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; +import * as fs from "fs"; + +export const run = async () => { + /* Initialize the LLM to use to answer the question */ + const model = new OpenAI({}); + /* Load in the file we want to do question answering over */ + const text = fs.readFileSync("state_of_the_union.txt", "utf8"); + /* Split the text into chunks */ + const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 }); + const docs = await textSplitter.createDocuments([text]); + /* Create the vectorstore */ + const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings()); + /* Create the chain */ + const chain = ConversationalRetrievalQAChain.fromLLM( + model, + vectorStore.asRetriever() + ); + /* Ask it a question */ + const question = "What did the president say about Justice Breyer?"; + const res = await chain.call({ question, chat_history: [] }); + console.log(res); + /* Ask it a follow up question */ + const chatHistory = question + res.text; + const followUpRes = await chain.call({ + question: "Was that nice?", + chat_history: chatHistory, + }); + console.log(followUpRes); +}; diff --git a/langchain/src/chains/conversational_retrieval_chain.ts b/langchain/src/chains/conversational_retrieval_chain.ts index 7c3a141393fe..2d44dd31c329 100644 --- a/langchain/src/chains/conversational_retrieval_chain.ts +++ b/langchain/src/chains/conversational_retrieval_chain.ts @@ -1,7 +1,11 @@ import { PromptTemplate } from "../prompts/prompt.js"; import { BaseLanguageModel } from "../base_language/index.js"; import { SerializedChatVectorDBQAChain } from "./serde.js"; -import { ChainValues, BaseRetriever } from "../schema/index.js"; +import { + ChainValues, + BaseRetriever, + BaseChatMessage, +} from "../schema/index.js"; import { BaseChain, ChainInputs } from "./base.js"; import { LLMChain } from "./llm_chain.js"; import { QAChainParams, loadQAChain } from "./question_answering/load.js"; @@ -17,8 +21,7 @@ Chat History: Follow Up Input: {question} Standalone question:`; -export interface ConversationalRetrievalQAChainInput - extends Omit { +export interface ConversationalRetrievalQAChainInput extends ChainInputs { retriever: BaseRetriever; combineDocumentsChain: BaseChain; questionGeneratorChain: LLMChain; @@ -62,6 +65,23 @@ export class ConversationalRetrievalQAChain fields.returnSourceDocuments ?? this.returnSourceDocuments; } + static getChatHistoryString(chatHistory: string | BaseChatMessage[]) { + if (Array.isArray(chatHistory)) { + return chatHistory + .map((chatMessage) => { + if (chatMessage._getType() === "human") { + return `Human: ${chatMessage.text}`; + } else if (chatMessage._getType() === "ai") { + return `Assistant: ${chatMessage.text}`; + } else { + return `${chatMessage.text}`; + } + }) + .join("\n"); + } + return chatHistory; + } + /** @ignore */ async _call( values: ChainValues, @@ -71,10 +91,13 @@ export class ConversationalRetrievalQAChain throw new Error(`Question key ${this.inputKey} not found.`); } if (!(this.chatHistoryKey in values)) { - throw new Error(`chat history key ${this.inputKey} not found.`); + throw new Error(`Chat history key ${this.chatHistoryKey} not found.`); } const question: string = values[this.inputKey]; - const chatHistory: string = values[this.chatHistoryKey]; + const chatHistory: string = + ConversationalRetrievalQAChain.getChatHistoryString( + values[this.chatHistoryKey] + ); let newQuestion = question; if (chatHistory.length > 0) { const result = await this.questionGeneratorChain.call( diff --git a/langchain/src/chains/tests/conversational_retrieval_chain.int.test.ts b/langchain/src/chains/tests/conversational_retrieval_chain.int.test.ts index 478e6cb50060..dec5c51111b8 100644 --- a/langchain/src/chains/tests/conversational_retrieval_chain.int.test.ts +++ b/langchain/src/chains/tests/conversational_retrieval_chain.int.test.ts @@ -5,6 +5,7 @@ import { HNSWLib } from "../../vectorstores/hnswlib.js"; import { OpenAIEmbeddings } from "../../embeddings/openai.js"; import { ChatOpenAI } from "../../chat_models/openai.js"; import { PromptTemplate } from "../../prompts/index.js"; +import { BufferMemory } from "../../memory/buffer_memory.js"; test("Test ConversationalRetrievalQAChain from LLM", async () => { const model = new OpenAI({ modelName: "text-ada-001" }); @@ -142,3 +143,154 @@ test("Test ConversationalRetrievalQAChain from LLM with a map reduce chain", asy console.log({ res }); }); + +test("Test ConversationalRetrievalQAChain from LLM without memory", async () => { + const model = new OpenAI({ + temperature: 0, + }); + const vectorStore = await HNSWLib.fromTexts( + [ + "Mitochondria are the powerhouse of the cell", + "Foo is red", + "Bar is red", + "Buildings are made out of brick", + "Mitochondria are made of lipids", + ], + [{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }], + new OpenAIEmbeddings() + ); + + const chain = ConversationalRetrievalQAChain.fromLLM( + model, + vectorStore.asRetriever() + ); + const question = "What is the powerhouse of the cell?"; + const res = await chain.call({ + question, + chat_history: "", + }); + + console.log({ res }); + + const res2 = await chain.call({ + question: "What are they made out of?", + chat_history: question + res.text, + }); + + console.log({ res2 }); +}); + +test("Test ConversationalRetrievalQAChain from LLM with a chat model without memory", async () => { + const model = new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + temperature: 0, + }); + const vectorStore = await HNSWLib.fromTexts( + [ + "Mitochondria are the powerhouse of the cell", + "Foo is red", + "Bar is red", + "Buildings are made out of brick", + "Mitochondria are made of lipids", + ], + [{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }], + new OpenAIEmbeddings() + ); + + const chain = ConversationalRetrievalQAChain.fromLLM( + model, + vectorStore.asRetriever() + ); + const question = "What is the powerhouse of the cell?"; + const res = await chain.call({ + question, + chat_history: "", + }); + + console.log({ res }); + + const res2 = await chain.call({ + question: "What are they made out of?", + chat_history: question + res.text, + }); + + console.log({ res2 }); +}); + +test("Test ConversationalRetrievalQAChain from LLM with memory", async () => { + const model = new OpenAI({ + temperature: 0, + }); + const vectorStore = await HNSWLib.fromTexts( + [ + "Mitochondria are the powerhouse of the cell", + "Foo is red", + "Bar is red", + "Buildings are made out of brick", + "Mitochondria are made of lipids", + ], + [{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }], + new OpenAIEmbeddings() + ); + + const chain = ConversationalRetrievalQAChain.fromLLM( + model, + vectorStore.asRetriever(), + { + memory: new BufferMemory({ + memoryKey: "chat_history", + }), + } + ); + const res = await chain.call({ + question: "What is the powerhouse of the cell?", + }); + + console.log({ res }); + + const res2 = await chain.call({ + question: "What are they made out of?", + }); + + console.log({ res2 }); +}); + +test("Test ConversationalRetrievalQAChain from LLM with a chat model and memory", async () => { + const model = new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + temperature: 0, + }); + const vectorStore = await HNSWLib.fromTexts( + [ + "Mitochondria are the powerhouse of the cell", + "Foo is red", + "Bar is red", + "Buildings are made out of brick", + "Mitochondria are made of lipids", + ], + [{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }], + new OpenAIEmbeddings() + ); + + const chain = ConversationalRetrievalQAChain.fromLLM( + model, + vectorStore.asRetriever(), + { + memory: new BufferMemory({ + memoryKey: "chat_history", + returnMessages: true, + }), + } + ); + const res = await chain.call({ + question: "What is the powerhouse of the cell?", + }); + + console.log({ res }); + + const res2 = await chain.call({ + question: "What are they made out of?", + }); + + console.log({ res2 }); +});