Skip to content

Commit

Permalink
Adds pipeline prompt (#1526)
Browse files Browse the repository at this point in the history
* Adds pipeline prompt

* Adds test for .partial

* Fix formatting

* Update pipeline prompt type parameter name, adds docs

* Fix example
  • Loading branch information
jacoblee93 authored Jun 4, 2023
1 parent 37d615a commit 905b9df
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 1 deletion.
23 changes: 23 additions & 0 deletions docs/docs/modules/prompts/prompt_templates/prompt_composition.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
---
hide_table_of_contents: true
sidebar_position: 1
---

import CodeBlock from "@theme/CodeBlock";

# Prompt Composition

Pipeline prompt templates allow you to compose multiple individual prompt templates together.
This can be useful when you want to reuse parts of individual prompts.

Rather than taking `inputVariables` as an argument, pipeline prompt templates take two new arguments:

- `pipelinePrompts`: An array of objects containing a string (`name`) and a `PromptTemplate` passed in as `promptTemplate`.
Each `PromptTemplate` will be formatted and then passed to the next prompt template in the pipeline as an input variable with the same name as `name`.
- `finalPrompt`: The final prompt that will be returned.

Here's an example of what this looks like in action:

import Example from "@examples/prompts/pipeline_prompt.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>
59 changes: 59 additions & 0 deletions examples/src/prompts/pipeline_prompt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { PromptTemplate, PipelinePromptTemplate } from "langchain/prompts";

const fullPrompt = PromptTemplate.fromTemplate(`{introduction}
{example}
{start}`);

const introductionPrompt = PromptTemplate.fromTemplate(
`You are impersonating {person}.`
);

const examplePrompt =
PromptTemplate.fromTemplate(`Here's an example of an interaction:
Q: {example_q}
A: {example_a}`);

const startPrompt = PromptTemplate.fromTemplate(`Now, do this for real!
Q: {input}
A:`);

const composedPrompt = new PipelinePromptTemplate({
pipelinePrompts: [
{
name: "introduction",
prompt: introductionPrompt,
},
{
name: "example",
prompt: examplePrompt,
},
{
name: "start",
prompt: startPrompt,
},
],
finalPrompt: fullPrompt,
});

const formattedPrompt = await composedPrompt.format({
person: "Elon Musk",
example_q: `What's your favorite car?`,
example_a: "Telsa",
input: `What's your favorite social media site?`,
});

console.log(formattedPrompt);

/*
You are impersonating Elon Musk.
Here's an example of an interaction:
Q: What's your favorite car?
A: Telsa
Now, do this for real!
Q: What's your favorite social media site?
A:
*/
2 changes: 2 additions & 0 deletions langchain/src/prompts/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export interface BasePromptTemplateInput {
* string prompt given a set of input values.
*/
export abstract class BasePromptTemplate implements BasePromptTemplateInput {
declare PromptValueReturnType: BasePromptValue;

inputVariables: string[];

outputParser?: BaseOutputParser;
Expand Down
4 changes: 3 additions & 1 deletion langchain/src/prompts/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ export abstract class BaseMessageStringPromptTemplate extends BaseMessagePromptT
}

export abstract class BaseChatPromptTemplate extends BasePromptTemplate {
declare PromptValueReturnType: ChatPromptValue;

constructor(input: BasePromptTemplateInput) {
super(input);
}
Expand All @@ -96,7 +98,7 @@ export abstract class BaseChatPromptTemplate extends BasePromptTemplate {
return (await this.formatPromptValue(values)).toString();
}

async formatPromptValue(values: InputValues): Promise<BasePromptValue> {
async formatPromptValue(values: InputValues): Promise<ChatPromptValue> {
const resultMessages = await this.formatMessages(values);
return new ChatPromptValue(resultMessages);
}
Expand Down
5 changes: 5 additions & 0 deletions langchain/src/prompts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ export {
checkValidTemplate,
TemplateFormat,
} from "./template.js";
export {
PipelinePromptParams,
PipelinePromptTemplate,
PipelinePromptTemplateInput,
} from "./pipeline.js";
120 changes: 120 additions & 0 deletions langchain/src/prompts/pipeline.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import { InputValues, PartialValues } from "../schema/index.js";
import { BasePromptTemplate, BasePromptTemplateInput } from "./base.js";
import { ChatPromptTemplate } from "./chat.js";
import { SerializedBasePromptTemplate } from "./serde.js";

export type PipelinePromptParams<
PromptTemplateType extends BasePromptTemplate
> = {
name: string;
prompt: PromptTemplateType;
};

export type PipelinePromptTemplateInput<
PromptTemplateType extends BasePromptTemplate
> = Omit<BasePromptTemplateInput, "inputVariables"> & {
pipelinePrompts: PipelinePromptParams<PromptTemplateType>[];
finalPrompt: PromptTemplateType;
};

export class PipelinePromptTemplate<
PromptTemplateType extends BasePromptTemplate
> extends BasePromptTemplate {
pipelinePrompts: PipelinePromptParams<PromptTemplateType>[];

finalPrompt: PromptTemplateType;

constructor(input: PipelinePromptTemplateInput<PromptTemplateType>) {
super({ ...input, inputVariables: [] });
this.pipelinePrompts = input.pipelinePrompts;
this.finalPrompt = input.finalPrompt;
this.inputVariables = this.computeInputValues();
}

protected computeInputValues() {
const intermediateValues = this.pipelinePrompts.map(
(pipelinePrompt) => pipelinePrompt.name
);
const inputValues = this.pipelinePrompts
.map((pipelinePrompt) =>
pipelinePrompt.prompt.inputVariables.filter(
(inputValue) => !intermediateValues.includes(inputValue)
)
)
.flat();
return [...new Set(inputValues)];
}

protected static extractRequiredInputValues(
allValues: InputValues,
requiredValueNames: string[]
) {
return requiredValueNames.reduce((requiredValues, valueName) => {
// eslint-disable-next-line no-param-reassign
requiredValues[valueName] = allValues[valueName];
return requiredValues;
}, {} as InputValues);
}

protected async formatPipelinePrompts(
values: InputValues
): Promise<InputValues> {
const allValues = await this.mergePartialAndUserVariables(values);
for (const { name: pipelinePromptName, prompt: pipelinePrompt } of this
.pipelinePrompts) {
const pipelinePromptInputValues =
PipelinePromptTemplate.extractRequiredInputValues(
allValues,
pipelinePrompt.inputVariables
);
// eslint-disable-next-line no-instanceof/no-instanceof
if (pipelinePrompt instanceof ChatPromptTemplate) {
allValues[pipelinePromptName] = await pipelinePrompt.formatMessages(
pipelinePromptInputValues
);
} else {
allValues[pipelinePromptName] = await pipelinePrompt.format(
pipelinePromptInputValues
);
}
}
return PipelinePromptTemplate.extractRequiredInputValues(
allValues,
this.finalPrompt.inputVariables
);
}

async formatPromptValue(
values: InputValues
): Promise<PromptTemplateType["PromptValueReturnType"]> {
return this.finalPrompt.formatPromptValue(
await this.formatPipelinePrompts(values)
);
}

async format(values: InputValues): Promise<string> {
return this.finalPrompt.format(await this.formatPipelinePrompts(values));
}

async partial(
values: PartialValues
): Promise<PipelinePromptTemplate<PromptTemplateType>> {
const promptDict = { ...this };
promptDict.inputVariables = this.inputVariables.filter(
(iv) => !(iv in values)
);
promptDict.partialVariables = {
...(this.partialVariables ?? {}),
...values,
};
return new PipelinePromptTemplate<PromptTemplateType>(promptDict);
}

serialize(): SerializedBasePromptTemplate {
throw new Error("Not implemented.");
}

_getPromptType(): string {
return "pipeline";
}
}
120 changes: 120 additions & 0 deletions langchain/src/prompts/tests/pipeline.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import { expect, test } from "@jest/globals";
import { PromptTemplate } from "../prompt.js";
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
} from "../chat.js";
import { PipelinePromptTemplate } from "../pipeline.js";

test("Test pipeline input variables", async () => {
const prompt = new PipelinePromptTemplate({
pipelinePrompts: [
{
name: "bar",
prompt: PromptTemplate.fromTemplate("{foo}"),
},
],
finalPrompt: PromptTemplate.fromTemplate("{bar}"),
});
expect(prompt.inputVariables).toEqual(["foo"]);
});

test("Test simple pipeline", async () => {
const prompt = new PipelinePromptTemplate({
pipelinePrompts: [
{
name: "bar",
prompt: PromptTemplate.fromTemplate("{foo}"),
},
],
finalPrompt: PromptTemplate.fromTemplate("{bar}"),
});
expect(
await prompt.format({
foo: "jim",
})
).toEqual("jim");
});

test("Test multi variable pipeline", async () => {
const prompt = new PipelinePromptTemplate({
pipelinePrompts: [
{
name: "bar",
prompt: PromptTemplate.fromTemplate("{foo}"),
},
],
finalPrompt: PromptTemplate.fromTemplate("okay {bar} {baz}"),
});
expect(
await prompt.format({
foo: "jim",
baz: "halpert",
})
).toEqual("okay jim halpert");
});

test("Test longer pipeline", async () => {
const prompt = new PipelinePromptTemplate({
pipelinePrompts: [
{
name: "bar",
prompt: PromptTemplate.fromTemplate("{foo}"),
},
{
name: "qux",
prompt: PromptTemplate.fromTemplate("hi {bar}"),
},
],
finalPrompt: PromptTemplate.fromTemplate("okay {qux} {baz}"),
});
expect(
await prompt.format({
foo: "pam",
baz: "beasley",
})
).toEqual("okay hi pam beasley");
});

test("Test with .partial", async () => {
const prompt = new PipelinePromptTemplate({
pipelinePrompts: [
{
name: "bar",
prompt: PromptTemplate.fromTemplate("{foo}"),
},
],
finalPrompt: PromptTemplate.fromTemplate("okay {bar} {baz}"),
});
const partialPrompt = await prompt.partial({
baz: "schrute",
});
expect(
await partialPrompt.format({
foo: "dwight",
})
).toEqual("okay dwight schrute");
});

test("Test with chat prompts", async () => {
const prompt = new PipelinePromptTemplate({
pipelinePrompts: [
{
name: "foo",
prompt: ChatPromptTemplate.fromPromptMessages([
HumanMessagePromptTemplate.fromTemplate(`{name} halpert`),
]),
},
],
finalPrompt: ChatPromptTemplate.fromPromptMessages([
SystemMessagePromptTemplate.fromTemplate("What is your name?"),
new MessagesPlaceholder("foo"),
]),
});
const formattedPromptValue = await prompt.formatPromptValue({
name: "pam",
});
expect(formattedPromptValue.messages[1].text).toEqual("pam halpert");
});

1 comment on commit 905b9df

@vercel
Copy link

@vercel vercel bot commented on 905b9df Jun 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.