Adds support for built-in memory to ConversationalQARetrievalChain (#1463)

* Adds support for built-in memory to ConversationalQARetrievalChain

* Fix typo in docs

* More docs updates

* Fix typo in docs
jacoblee93 authored May 30, 2023
1 parent 8e7fcc8 commit b780ab4
Showing 6 changed files with 279 additions and 13 deletions.
@@ -10,15 +10,15 @@ import ConvoRetrievalQAExample from "@examples/chains/conversational_qa.ts";

The `ConversationalRetrievalQA` chain builds on `RetrievalQAChain` to provide a chat history component.

It requires two inputs: a question and the chat history. It first combines the chat history and the question into a standalone question, then looks up relevant documents from the retriever, and then passes those documents and the question to a question answering chain to return a response.
It first combines the chat history (either explicitly passed in or retrieved from the provided memory) and the question into a standalone question, then looks up relevant documents from the retriever, and finally passes those documents and the question to a question answering chain to return a response.
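
Conceptually, the two steps look roughly like the hypothetical sketch below. The condensing prompt is copied from the chain's default template shown later in this commit; the helper itself and the `loadQAStuffChain` choice are illustrative assumptions, not the chain's actual implementation:

```typescript
import { OpenAI } from "langchain/llms/openai";
import { LLMChain, loadQAStuffChain } from "langchain/chains";
import { PromptTemplate } from "langchain/prompts";
import { BaseRetriever } from "langchain/schema";

// Hypothetical helper mirroring what the chain wires together for you internally.
export const answerWithHistory = async (
  retriever: BaseRetriever,
  question: string,
  chatHistory: string
) => {
  const llm = new OpenAI({ temperature: 0 });
  // Step 1: condense the chat history and the follow-up question into a standalone question.
  const condensePrompt = PromptTemplate.fromTemplate(
    `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:`
  );
  const questionGenerator = new LLMChain({ llm, prompt: condensePrompt });
  const { text: standaloneQuestion } = await questionGenerator.call({
    chat_history: chatHistory,
    question,
  });
  // Step 2: look up relevant documents and pass them to a question answering chain.
  const docs = await retriever.getRelevantDocuments(standaloneQuestion);
  const qaChain = loadQAStuffChain(llm);
  return qaChain.call({ input_documents: docs, question: standaloneQuestion });
};
```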

To create one, you will need a retriever. In the below example, we will create one from a vectorstore, which can be created from embeddings.
To create one, you will need a retriever. In the below example, we will create one from a vector store, which can be created from embeddings.
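
For instance, a minimal sketch of that construction, using the same `HNSWLib` vector store and `OpenAIEmbeddings` that the examples below use:

```typescript
import { HNSWLib } from "langchain/vectorstores/hnswlib";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";

// Embeddings -> vector store -> retriever
const vectorStore = await HNSWLib.fromTexts(
  ["Mitochondria are the powerhouse of the cell", "Buildings are made out of brick"],
  [{ id: 1 }, { id: 2 }],
  new OpenAIEmbeddings()
);
const retriever = vectorStore.asRetriever();
```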

import Example from "@examples/chains/conversational_qa.ts";

<CodeBlock language="typescript">{ConvoRetrievalQAExample}</CodeBlock>

In this code snippet, the fromLLM method of the `ConversationalRetrievalQAChain` class has the following signature:
In the above code snippet, the fromLLM method of the `ConversationalRetrievalQAChain` class has the following signature:

```typescript
static fromLLM(
@@ -46,4 +46,17 @@ Here's an explanation of each of the attributes of the options object:
You can see [documentation about the usable fields here](/docs/api/chains/types/QAChainParams).
- `returnSourceDocuments`: A boolean value that indicates whether the `ConversationalRetrievalQAChain` should return the source documents that were used to retrieve the answer. If set to true, the documents will be included in the result returned by the call() method. This can be useful if you want to allow the user to see the sources used to generate the answer. If not set, the default value will be false.
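
For example, a sketch that reuses the `model` and `vectorStore` from the example above (the `sourceDocuments` result key and the `type: "stuff"` value are assumptions based on the linked `QAChainParams` documentation):

```typescript
// Assumes `model` and `vectorStore` from the example above.
const chain = ConversationalRetrievalQAChain.fromLLM(
  model,
  vectorStore.asRetriever(),
  {
    qaChainOptions: { type: "stuff" },
    returnSourceDocuments: true,
  }
);
const res = await chain.call({
  question: "What did the president say about Justice Breyer?",
  chat_history: "",
});
console.log(res.text);
// With returnSourceDocuments set, the retrieved documents are also on the result:
console.log(res.sourceDocuments);
```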

In summary, the `questionGeneratorChainOptions`, `qaChainOptions`, and `returnSourceDocuments` options allow the user to customize the behavior of the `ConversationalRetrievalQAChain`
Here's a customization example using a faster LLM to generate questions and a slower, more comprehensive LLM for the final answer:

import ConvoQACustomizationExample from "@examples/chains/conversational_qa_customization.ts";

<CodeBlock language="typescript">{ConvoQACustomizationExample}</CodeBlock>

## External Memory

If you'd like to format the chat history in a specific way, you can instead pass it in explicitly by omitting the `memory` option and passing
a `chat_history` string directly to the `chain.call` method:
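
For example (a sketch; `formatTurns` is a hypothetical helper, and `chain` is assumed to be constructed without the `memory` option, as in the example below):

```typescript
// Hypothetical helper: format prior turns however you like before passing them in.
const formatTurns = (turns: Array<{ human: string; ai: string }>) =>
  turns
    .map(({ human, ai }) => `Human: ${human}\nAssistant: ${ai}`)
    .join("\n");

const firstQuestion = "What did the president say about Justice Breyer?";
const firstRes = await chain.call({ question: firstQuestion, chat_history: "" });

const followUpRes = await chain.call({
  question: "Was that nice?",
  chat_history: formatTurns([{ human: firstQuestion, ai: firstRes.text }]),
});
console.log(followUpRes.text);
```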

import ConvoQAExternalMemoryExample from "@examples/chains/conversational_qa_external_memory.ts";

<CodeBlock language="typescript">{ConvoQAExternalMemoryExample}</CodeBlock>
12 changes: 8 additions & 4 deletions examples/src/chains/conversational_qa.ts
Expand Up @@ -3,6 +3,7 @@ import { ConversationalRetrievalQAChain } from "langchain/chains";
import { HNSWLib } from "langchain/vectorstores/hnswlib";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { BufferMemory } from "langchain/memory";
import * as fs from "fs";

export const run = async () => {
@@ -18,17 +19,20 @@ export const run = async () => {
/* Create the chain */
const chain = ConversationalRetrievalQAChain.fromLLM(
model,
vectorStore.asRetriever()
vectorStore.asRetriever(),
{
memory: new BufferMemory({
memoryKey: "chat_history", // Must be set to "chat_history"
}),
}
);
/* Ask it a question */
const question = "What did the president say about Justice Breyer?";
const res = await chain.call({ question, chat_history: [] });
const res = await chain.call({ question });
console.log(res);
/* Ask it a follow up question */
const chatHistory = question + res.text;
const followUpRes = await chain.call({
question: "Was that nice?",
chat_history: chatHistory,
});
console.log(followUpRes);
};
40 changes: 40 additions & 0 deletions examples/src/chains/conversational_qa_customization.ts
@@ -0,0 +1,40 @@
import { ChatOpenAI } from "langchain/chat_models/openai";
import { ConversationalRetrievalQAChain } from "langchain/chains";
import { HNSWLib } from "langchain/vectorstores/hnswlib";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { BufferMemory } from "langchain/memory";
import * as fs from "fs";

export const run = async () => {
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const docs = await textSplitter.createDocuments([text]);
const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings());
const fasterModel = new ChatOpenAI({
modelName: "gpt-3.5-turbo",
});
const slowerModel = new ChatOpenAI({
modelName: "gpt-4",
});
const chain = ConversationalRetrievalQAChain.fromLLM(
slowerModel,
vectorStore.asRetriever(),
{
memory: new BufferMemory({
memoryKey: "chat_history",
returnMessages: true,
}),
questionGeneratorChainOptions: {
llm: fasterModel,
},
}
);
/* Ask it a question */
const question = "What did the president say about Justice Breyer?";
const res = await chain.call({ question });
console.log(res);

const followUpRes = await chain.call({ question: "Was that nice?" });
console.log(followUpRes);
};
34 changes: 34 additions & 0 deletions examples/src/chains/conversational_qa_external_memory.ts
@@ -0,0 +1,34 @@
import { OpenAI } from "langchain/llms/openai";
import { ConversationalRetrievalQAChain } from "langchain/chains";
import { HNSWLib } from "langchain/vectorstores/hnswlib";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import * as fs from "fs";

export const run = async () => {
/* Initialize the LLM to use to answer the question */
const model = new OpenAI({});
/* Load in the file we want to do question answering over */
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
/* Split the text into chunks */
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const docs = await textSplitter.createDocuments([text]);
/* Create the vectorstore */
const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings());
/* Create the chain */
const chain = ConversationalRetrievalQAChain.fromLLM(
model,
vectorStore.asRetriever()
);
/* Ask it a question */
const question = "What did the president say about Justice Breyer?";
const res = await chain.call({ question, chat_history: [] });
console.log(res);
/* Ask it a follow up question */
const chatHistory = question + res.text;
const followUpRes = await chain.call({
question: "Was that nice?",
chat_history: chatHistory,
});
console.log(followUpRes);
};
33 changes: 28 additions & 5 deletions langchain/src/chains/conversational_retrieval_chain.ts
@@ -1,7 +1,11 @@
import { PromptTemplate } from "../prompts/prompt.js";
import { BaseLanguageModel } from "../base_language/index.js";
import { SerializedChatVectorDBQAChain } from "./serde.js";
import { ChainValues, BaseRetriever } from "../schema/index.js";
import {
ChainValues,
BaseRetriever,
BaseChatMessage,
} from "../schema/index.js";
import { BaseChain, ChainInputs } from "./base.js";
import { LLMChain } from "./llm_chain.js";
import { QAChainParams, loadQAChain } from "./question_answering/load.js";
@@ -17,8 +21,7 @@ Chat History:
Follow Up Input: {question}
Standalone question:`;

export interface ConversationalRetrievalQAChainInput
extends Omit<ChainInputs, "memory"> {
export interface ConversationalRetrievalQAChainInput extends ChainInputs {
retriever: BaseRetriever;
combineDocumentsChain: BaseChain;
questionGeneratorChain: LLMChain;
@@ -62,6 +65,23 @@ export class ConversationalRetrievalQAChain
fields.returnSourceDocuments ?? this.returnSourceDocuments;
}

static getChatHistoryString(chatHistory: string | BaseChatMessage[]) {
if (Array.isArray(chatHistory)) {
return chatHistory
.map((chatMessage) => {
if (chatMessage._getType() === "human") {
return `Human: ${chatMessage.text}`;
} else if (chatMessage._getType() === "ai") {
return `Assistant: ${chatMessage.text}`;
} else {
return `${chatMessage.text}`;
}
})
.join("\n");
}
return chatHistory;
}

/** @ignore */
async _call(
values: ChainValues,
@@ -71,10 +91,13 @@
throw new Error(`Question key ${this.inputKey} not found.`);
}
if (!(this.chatHistoryKey in values)) {
throw new Error(`chat history key ${this.inputKey} not found.`);
throw new Error(`Chat history key ${this.chatHistoryKey} not found.`);
}
const question: string = values[this.inputKey];
const chatHistory: string = values[this.chatHistoryKey];
const chatHistory: string =
ConversationalRetrievalQAChain.getChatHistoryString(
values[this.chatHistoryKey]
);
let newQuestion = question;
if (chatHistory.length > 0) {
const result = await this.questionGeneratorChain.call(
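
Because `_call` now normalizes the incoming history through `getChatHistoryString`, memory configured with `returnMessages: true` (which stores `BaseChatMessage[]` rather than a string) works as well. A sketch of the transcript format the new helper produces, with an illustrative one-turn history:

```typescript
import { ConversationalRetrievalQAChain } from "langchain/chains";
import { HumanChatMessage, AIChatMessage } from "langchain/schema";

const history = [
  new HumanChatMessage("What did the president say about Justice Breyer?"),
  new AIChatMessage("He thanked him for his service."),
];
// Messages are flattened into a "Human:"/"Assistant:" transcript string:
console.log(ConversationalRetrievalQAChain.getChatHistoryString(history));
// Human: What did the president say about Justice Breyer?
// Assistant: He thanked him for his service.
```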
152 changes: 152 additions & 0 deletions langchain/src/chains/tests/conversational_retrieval_chain.int.test.ts
@@ -5,6 +5,7 @@ import { HNSWLib } from "../../vectorstores/hnswlib.js";
import { OpenAIEmbeddings } from "../../embeddings/openai.js";
import { ChatOpenAI } from "../../chat_models/openai.js";
import { PromptTemplate } from "../../prompts/index.js";
import { BufferMemory } from "../../memory/buffer_memory.js";

test("Test ConversationalRetrievalQAChain from LLM", async () => {
const model = new OpenAI({ modelName: "text-ada-001" });
@@ -142,3 +143,154 @@ test("Test ConversationalRetrievalQAChain from LLM with a map reduce chain", async

console.log({ res });
});

test("Test ConversationalRetrievalQAChain from LLM without memory", async () => {
const model = new OpenAI({
temperature: 0,
});
const vectorStore = await HNSWLib.fromTexts(
[
"Mitochondria are the powerhouse of the cell",
"Foo is red",
"Bar is red",
"Buildings are made out of brick",
"Mitochondria are made of lipids",
],
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);

const chain = ConversationalRetrievalQAChain.fromLLM(
model,
vectorStore.asRetriever()
);
const question = "What is the powerhouse of the cell?";
const res = await chain.call({
question,
chat_history: "",
});

console.log({ res });

const res2 = await chain.call({
question: "What are they made out of?",
chat_history: question + res.text,
});

console.log({ res2 });
});

test("Test ConversationalRetrievalQAChain from LLM with a chat model without memory", async () => {
const model = new ChatOpenAI({
modelName: "gpt-3.5-turbo",
temperature: 0,
});
const vectorStore = await HNSWLib.fromTexts(
[
"Mitochondria are the powerhouse of the cell",
"Foo is red",
"Bar is red",
"Buildings are made out of brick",
"Mitochondria are made of lipids",
],
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);

const chain = ConversationalRetrievalQAChain.fromLLM(
model,
vectorStore.asRetriever()
);
const question = "What is the powerhouse of the cell?";
const res = await chain.call({
question,
chat_history: "",
});

console.log({ res });

const res2 = await chain.call({
question: "What are they made out of?",
chat_history: question + res.text,
});

console.log({ res2 });
});

test("Test ConversationalRetrievalQAChain from LLM with memory", async () => {
const model = new OpenAI({
temperature: 0,
});
const vectorStore = await HNSWLib.fromTexts(
[
"Mitochondria are the powerhouse of the cell",
"Foo is red",
"Bar is red",
"Buildings are made out of brick",
"Mitochondria are made of lipids",
],
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);

const chain = ConversationalRetrievalQAChain.fromLLM(
model,
vectorStore.asRetriever(),
{
memory: new BufferMemory({
memoryKey: "chat_history",
}),
}
);
const res = await chain.call({
question: "What is the powerhouse of the cell?",
});

console.log({ res });

const res2 = await chain.call({
question: "What are they made out of?",
});

console.log({ res2 });
});

test("Test ConversationalRetrievalQAChain from LLM with a chat model and memory", async () => {
const model = new ChatOpenAI({
modelName: "gpt-3.5-turbo",
temperature: 0,
});
const vectorStore = await HNSWLib.fromTexts(
[
"Mitochondria are the powerhouse of the cell",
"Foo is red",
"Bar is red",
"Buildings are made out of brick",
"Mitochondria are made of lipids",
],
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
new OpenAIEmbeddings()
);

const chain = ConversationalRetrievalQAChain.fromLLM(
model,
vectorStore.asRetriever(),
{
memory: new BufferMemory({
memoryKey: "chat_history",
returnMessages: true,
}),
}
);
const res = await chain.call({
question: "What is the powerhouse of the cell?",
});

console.log({ res });

const res2 = await chain.call({
question: "What are they made out of?",
});

console.log({ res2 });
});
