From 905b9df319c222db4c2bdf339e9d81d77ced7831 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Sun, 4 Jun 2023 15:47:15 -0700 Subject: [PATCH] Adds pipeline prompt (#1526) * Adds pipeline prompt * Adds test for .partial * Fix formatting * Update pipeline prompt type parameter name, adds docs * Fix example --- .../prompt_templates/prompt_composition.mdx | 23 ++++ examples/src/prompts/pipeline_prompt.ts | 59 +++++++++ langchain/src/prompts/base.ts | 2 + langchain/src/prompts/chat.ts | 4 +- langchain/src/prompts/index.ts | 5 + langchain/src/prompts/pipeline.ts | 120 ++++++++++++++++++ langchain/src/prompts/tests/pipeline.test.ts | 120 ++++++++++++++++++ 7 files changed, 332 insertions(+), 1 deletion(-) create mode 100644 docs/docs/modules/prompts/prompt_templates/prompt_composition.mdx create mode 100644 examples/src/prompts/pipeline_prompt.ts create mode 100644 langchain/src/prompts/pipeline.ts create mode 100644 langchain/src/prompts/tests/pipeline.test.ts diff --git a/docs/docs/modules/prompts/prompt_templates/prompt_composition.mdx b/docs/docs/modules/prompts/prompt_templates/prompt_composition.mdx new file mode 100644 index 000000000000..b715de2f5d88 --- /dev/null +++ b/docs/docs/modules/prompts/prompt_templates/prompt_composition.mdx @@ -0,0 +1,23 @@ +--- +hide_table_of_contents: true +sidebar_position: 1 +--- + +import CodeBlock from "@theme/CodeBlock"; + +# Prompt Composition + +Pipeline prompt templates allow you to compose multiple individual prompt templates together. +This can be useful when you want to reuse parts of individual prompts. + +Rather than taking `inputVariables` as an argument, pipeline prompt templates take two new arguments: + +- `pipelinePrompts`: An array of objects containing a string (`name`) and a `PromptTemplate` passed in as `promptTemplate`. + Each `PromptTemplate` will be formatted and then passed to the next prompt template in the pipeline as an input variable with the same name as `name`. +- `finalPrompt`: The final prompt that will be returned. + +Here's an example of what this looks like in action: + +import Example from "@examples/prompts/pipeline_prompt.ts"; + +{Example} diff --git a/examples/src/prompts/pipeline_prompt.ts b/examples/src/prompts/pipeline_prompt.ts new file mode 100644 index 000000000000..2ea7cf412e6a --- /dev/null +++ b/examples/src/prompts/pipeline_prompt.ts @@ -0,0 +1,59 @@ +import { PromptTemplate, PipelinePromptTemplate } from "langchain/prompts"; + +const fullPrompt = PromptTemplate.fromTemplate(`{introduction} + +{example} + +{start}`); + +const introductionPrompt = PromptTemplate.fromTemplate( + `You are impersonating {person}.` +); + +const examplePrompt = + PromptTemplate.fromTemplate(`Here's an example of an interaction: +Q: {example_q} +A: {example_a}`); + +const startPrompt = PromptTemplate.fromTemplate(`Now, do this for real! +Q: {input} +A:`); + +const composedPrompt = new PipelinePromptTemplate({ + pipelinePrompts: [ + { + name: "introduction", + prompt: introductionPrompt, + }, + { + name: "example", + prompt: examplePrompt, + }, + { + name: "start", + prompt: startPrompt, + }, + ], + finalPrompt: fullPrompt, +}); + +const formattedPrompt = await composedPrompt.format({ + person: "Elon Musk", + example_q: `What's your favorite car?`, + example_a: "Telsa", + input: `What's your favorite social media site?`, +}); + +console.log(formattedPrompt); + +/* + You are impersonating Elon Musk. + + Here's an example of an interaction: + Q: What's your favorite car? + A: Telsa + + Now, do this for real! + Q: What's your favorite social media site? + A: +*/ diff --git a/langchain/src/prompts/base.ts b/langchain/src/prompts/base.ts index 761b3aea9308..8df4d858409d 100644 --- a/langchain/src/prompts/base.ts +++ b/langchain/src/prompts/base.ts @@ -48,6 +48,8 @@ export interface BasePromptTemplateInput { * string prompt given a set of input values. */ export abstract class BasePromptTemplate implements BasePromptTemplateInput { + declare PromptValueReturnType: BasePromptValue; + inputVariables: string[]; outputParser?: BaseOutputParser; diff --git a/langchain/src/prompts/chat.ts b/langchain/src/prompts/chat.ts index 2d28b1a00617..d7f21fd0050c 100644 --- a/langchain/src/prompts/chat.ts +++ b/langchain/src/prompts/chat.ts @@ -86,6 +86,8 @@ export abstract class BaseMessageStringPromptTemplate extends BaseMessagePromptT } export abstract class BaseChatPromptTemplate extends BasePromptTemplate { + declare PromptValueReturnType: ChatPromptValue; + constructor(input: BasePromptTemplateInput) { super(input); } @@ -96,7 +98,7 @@ export abstract class BaseChatPromptTemplate extends BasePromptTemplate { return (await this.formatPromptValue(values)).toString(); } - async formatPromptValue(values: InputValues): Promise { + async formatPromptValue(values: InputValues): Promise { const resultMessages = await this.formatMessages(values); return new ChatPromptValue(resultMessages); } diff --git a/langchain/src/prompts/index.ts b/langchain/src/prompts/index.ts index 6f02084a5ed4..61b2b281ebe2 100644 --- a/langchain/src/prompts/index.ts +++ b/langchain/src/prompts/index.ts @@ -46,3 +46,8 @@ export { checkValidTemplate, TemplateFormat, } from "./template.js"; +export { + PipelinePromptParams, + PipelinePromptTemplate, + PipelinePromptTemplateInput, +} from "./pipeline.js"; diff --git a/langchain/src/prompts/pipeline.ts b/langchain/src/prompts/pipeline.ts new file mode 100644 index 000000000000..be7a93a98c1f --- /dev/null +++ b/langchain/src/prompts/pipeline.ts @@ -0,0 +1,120 @@ +import { InputValues, PartialValues } from "../schema/index.js"; +import { BasePromptTemplate, BasePromptTemplateInput } from "./base.js"; +import { ChatPromptTemplate } from "./chat.js"; +import { SerializedBasePromptTemplate } from "./serde.js"; + +export type PipelinePromptParams< + PromptTemplateType extends BasePromptTemplate +> = { + name: string; + prompt: PromptTemplateType; +}; + +export type PipelinePromptTemplateInput< + PromptTemplateType extends BasePromptTemplate +> = Omit & { + pipelinePrompts: PipelinePromptParams[]; + finalPrompt: PromptTemplateType; +}; + +export class PipelinePromptTemplate< + PromptTemplateType extends BasePromptTemplate +> extends BasePromptTemplate { + pipelinePrompts: PipelinePromptParams[]; + + finalPrompt: PromptTemplateType; + + constructor(input: PipelinePromptTemplateInput) { + super({ ...input, inputVariables: [] }); + this.pipelinePrompts = input.pipelinePrompts; + this.finalPrompt = input.finalPrompt; + this.inputVariables = this.computeInputValues(); + } + + protected computeInputValues() { + const intermediateValues = this.pipelinePrompts.map( + (pipelinePrompt) => pipelinePrompt.name + ); + const inputValues = this.pipelinePrompts + .map((pipelinePrompt) => + pipelinePrompt.prompt.inputVariables.filter( + (inputValue) => !intermediateValues.includes(inputValue) + ) + ) + .flat(); + return [...new Set(inputValues)]; + } + + protected static extractRequiredInputValues( + allValues: InputValues, + requiredValueNames: string[] + ) { + return requiredValueNames.reduce((requiredValues, valueName) => { + // eslint-disable-next-line no-param-reassign + requiredValues[valueName] = allValues[valueName]; + return requiredValues; + }, {} as InputValues); + } + + protected async formatPipelinePrompts( + values: InputValues + ): Promise { + const allValues = await this.mergePartialAndUserVariables(values); + for (const { name: pipelinePromptName, prompt: pipelinePrompt } of this + .pipelinePrompts) { + const pipelinePromptInputValues = + PipelinePromptTemplate.extractRequiredInputValues( + allValues, + pipelinePrompt.inputVariables + ); + // eslint-disable-next-line no-instanceof/no-instanceof + if (pipelinePrompt instanceof ChatPromptTemplate) { + allValues[pipelinePromptName] = await pipelinePrompt.formatMessages( + pipelinePromptInputValues + ); + } else { + allValues[pipelinePromptName] = await pipelinePrompt.format( + pipelinePromptInputValues + ); + } + } + return PipelinePromptTemplate.extractRequiredInputValues( + allValues, + this.finalPrompt.inputVariables + ); + } + + async formatPromptValue( + values: InputValues + ): Promise { + return this.finalPrompt.formatPromptValue( + await this.formatPipelinePrompts(values) + ); + } + + async format(values: InputValues): Promise { + return this.finalPrompt.format(await this.formatPipelinePrompts(values)); + } + + async partial( + values: PartialValues + ): Promise> { + const promptDict = { ...this }; + promptDict.inputVariables = this.inputVariables.filter( + (iv) => !(iv in values) + ); + promptDict.partialVariables = { + ...(this.partialVariables ?? {}), + ...values, + }; + return new PipelinePromptTemplate(promptDict); + } + + serialize(): SerializedBasePromptTemplate { + throw new Error("Not implemented."); + } + + _getPromptType(): string { + return "pipeline"; + } +} diff --git a/langchain/src/prompts/tests/pipeline.test.ts b/langchain/src/prompts/tests/pipeline.test.ts new file mode 100644 index 000000000000..e04109d42e91 --- /dev/null +++ b/langchain/src/prompts/tests/pipeline.test.ts @@ -0,0 +1,120 @@ +import { expect, test } from "@jest/globals"; +import { PromptTemplate } from "../prompt.js"; +import { + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, + SystemMessagePromptTemplate, +} from "../chat.js"; +import { PipelinePromptTemplate } from "../pipeline.js"; + +test("Test pipeline input variables", async () => { + const prompt = new PipelinePromptTemplate({ + pipelinePrompts: [ + { + name: "bar", + prompt: PromptTemplate.fromTemplate("{foo}"), + }, + ], + finalPrompt: PromptTemplate.fromTemplate("{bar}"), + }); + expect(prompt.inputVariables).toEqual(["foo"]); +}); + +test("Test simple pipeline", async () => { + const prompt = new PipelinePromptTemplate({ + pipelinePrompts: [ + { + name: "bar", + prompt: PromptTemplate.fromTemplate("{foo}"), + }, + ], + finalPrompt: PromptTemplate.fromTemplate("{bar}"), + }); + expect( + await prompt.format({ + foo: "jim", + }) + ).toEqual("jim"); +}); + +test("Test multi variable pipeline", async () => { + const prompt = new PipelinePromptTemplate({ + pipelinePrompts: [ + { + name: "bar", + prompt: PromptTemplate.fromTemplate("{foo}"), + }, + ], + finalPrompt: PromptTemplate.fromTemplate("okay {bar} {baz}"), + }); + expect( + await prompt.format({ + foo: "jim", + baz: "halpert", + }) + ).toEqual("okay jim halpert"); +}); + +test("Test longer pipeline", async () => { + const prompt = new PipelinePromptTemplate({ + pipelinePrompts: [ + { + name: "bar", + prompt: PromptTemplate.fromTemplate("{foo}"), + }, + { + name: "qux", + prompt: PromptTemplate.fromTemplate("hi {bar}"), + }, + ], + finalPrompt: PromptTemplate.fromTemplate("okay {qux} {baz}"), + }); + expect( + await prompt.format({ + foo: "pam", + baz: "beasley", + }) + ).toEqual("okay hi pam beasley"); +}); + +test("Test with .partial", async () => { + const prompt = new PipelinePromptTemplate({ + pipelinePrompts: [ + { + name: "bar", + prompt: PromptTemplate.fromTemplate("{foo}"), + }, + ], + finalPrompt: PromptTemplate.fromTemplate("okay {bar} {baz}"), + }); + const partialPrompt = await prompt.partial({ + baz: "schrute", + }); + expect( + await partialPrompt.format({ + foo: "dwight", + }) + ).toEqual("okay dwight schrute"); +}); + +test("Test with chat prompts", async () => { + const prompt = new PipelinePromptTemplate({ + pipelinePrompts: [ + { + name: "foo", + prompt: ChatPromptTemplate.fromPromptMessages([ + HumanMessagePromptTemplate.fromTemplate(`{name} halpert`), + ]), + }, + ], + finalPrompt: ChatPromptTemplate.fromPromptMessages([ + SystemMessagePromptTemplate.fromTemplate("What is your name?"), + new MessagesPlaceholder("foo"), + ]), + }); + const formattedPromptValue = await prompt.formatPromptValue({ + name: "pam", + }); + expect(formattedPromptValue.messages[1].text).toEqual("pam halpert"); +});