Allow runnables to implement transform streaming (#2156)
* Adds support for transform streaming on runnables

* Adds comment

* Fix types

* Adds EncodingOutputParser

* Rename to BytesOutputParser
jacoblee93 authored Aug 4, 2023
1 parent 1c1274d commit ae9895f
Showing 6 changed files with 227 additions and 16 deletions.
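In practice, the change lets a parser consume a model's token stream instead of waiting for a full response. A minimal sketch of the new behavior, mirroring the tests added below (assumes any LLM that implements `_streamResponseChunks`, like the `FakeStreamingLLM` test helper):

```ts
import { LLM } from "langchain/llms/base";
import { StringOutputParser } from "langchain/schema/output_parser";

declare const llm: LLM; // any streaming-capable model

// Each token chunk flows through the parser's transform() and is
// re-yielded downstream as soon as it arrives.
const stream = await llm.pipe(new StringOutputParser()).stream("Hi there!");
for await (const chunk of stream) {
  console.log(chunk); // "H", "i", " ", ...
}
```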
14 changes: 7 additions & 7 deletions langchain/src/chains/transform.ts
@@ -11,11 +11,11 @@ export interface TransformChainFields<
   outputVariables: (keyof O extends string ? keyof O : never)[];
 }
 
-export class TransformChain<I extends ChainValues, O extends ChainValues>
-  extends BaseChain
-  implements TransformChainFields<I, O>
-{
-  transform: (values: I, callbacks?: Callbacks) => O | Promise<O>;
+export class TransformChain<
+  I extends ChainValues,
+  O extends ChainValues
+> extends BaseChain {
+  transformFunc: (values: I, callbacks?: Callbacks) => O | Promise<O>;
 
   inputVariables: (keyof I extends string ? keyof I : never)[];
 
@@ -35,12 +35,12 @@ export class TransformChain<I extends ChainValues, O extends ChainValues>
 
   constructor(fields: TransformChainFields<I, O>) {
     super(fields);
-    this.transform = fields.transform;
+    this.transformFunc = fields.transform;
     this.inputVariables = fields.inputVariables;
     this.outputVariables = fields.outputVariables;
   }
 
   async _call(values: I, runManager?: CallbackManagerForChainRun): Promise<O> {
-    return this.transform(values, runManager?.getChild("transform"));
+    return this.transformFunc(values, runManager?.getChild("transform"));
   }
 }
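The field rename (`transform` → `transformFunc`, still populated from `fields.transform`) presumably avoids colliding with the new optional `Runnable.transform()` streaming method introduced below — a chain whose `transform` property is a plain function would otherwise look streaming-capable to `RunnableSequence`. Construction is unchanged; a quick sketch:

```ts
import { TransformChain } from "langchain/chains";

// The function is still passed as `transform`, but stored as `transformFunc`.
const chain = new TransformChain({
  transform: (values: { text: string }) => ({
    shouted: values.text.toUpperCase(),
  }),
  inputVariables: ["text"],
  outputVariables: ["shouted"],
});
// await chain.call({ text: "hi" }) -> { shouted: "HI" }
```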
17 changes: 14 additions & 3 deletions langchain/src/output_parsers/noop.ts
@@ -1,4 +1,15 @@
-import { StringOutputParser } from "../schema/output_parser.js";
+import { BaseOutputParser } from "../schema/output_parser.js";
 
 /** @deprecated Use StringOutputParser instead */
-export class NoOpOutputParser extends StringOutputParser {}
+export class NoOpOutputParser extends BaseOutputParser<string> {
+  lc_namespace = ["langchain", "output_parsers", "default"];
+
+  lc_serializable = true;
+
+  parse(text: string): Promise<string> {
+    return Promise.resolve(text);
+  }
+
+  getFormatInstructions(): string {
+    return "";
+  }
+}
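Behavior is unchanged by the reimplementation — the parser is still a pure passthrough — but it now carries its own serialization namespace instead of inheriting `StringOutputParser`'s. A minimal check (import path assumed from the file's location):

```ts
import { NoOpOutputParser } from "langchain/output_parsers";

const parser = new NoOpOutputParser();
await parser.parse("hello"); // "hello"
parser.getFormatInstructions(); // ""
```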
53 changes: 51 additions & 2 deletions langchain/src/schema/output_parser.ts
@@ -1,4 +1,4 @@
-import { Callbacks } from "../callbacks/manager.js";
+import { BaseCallbackConfig, Callbacks } from "../callbacks/manager.js";
 import {
   BasePromptValue,
   Generation,
@@ -101,10 +101,39 @@ export abstract class BaseOutputParser<
   }
 }
 
+/**
+ * Class to parse the output of an LLM call that also allows streaming inputs.
+ */
+export abstract class BaseTransformOutputParser<
+  T = unknown
+> extends BaseOutputParser<T> {
+  async *_transform(
+    inputGenerator: AsyncGenerator<string | BaseMessage>
+  ): AsyncGenerator<T> {
+    for await (const chunk of inputGenerator) {
+      if (typeof chunk === "string") {
+        yield this.parseResult([{ text: chunk }]);
+      } else {
+        yield this.parseResult([{ message: chunk, text: chunk.content }]);
+      }
+    }
+  }
+
+  async *transform(
+    inputGenerator: AsyncGenerator<string | BaseMessage>,
+    options: BaseCallbackConfig
+  ): AsyncGenerator<T> {
+    yield* this._streamWithConfig(this._transform(inputGenerator), {
+      ...options,
+      runType: "parser",
+    });
+  }
+}
+
 /**
  * OutputParser that parses LLMResult into the top likely string.
  */
-export class StringOutputParser extends BaseOutputParser<string> {
+export class StringOutputParser extends BaseTransformOutputParser<string> {
   lc_namespace = ["schema", "output_parser"];
 
   lc_serializable = true;
@@ -118,6 +147,26 @@ export class StringOutputParser extends BaseOutputParser<string> {
   }
 }
 
+/**
+ * OutputParser that parses LLMResult into the top likely string and
+ * encodes it into bytes.
+ */
+export class BytesOutputParser extends BaseTransformOutputParser<Uint8Array> {
+  lc_namespace = ["schema", "output_parser"];
+
+  lc_serializable = true;
+
+  protected textEncoder = new TextEncoder();
+
+  parse(text: string): Promise<Uint8Array> {
+    return Promise.resolve(this.textEncoder.encode(text));
+  }
+
+  getFormatInstructions(): string {
+    return "";
+  }
+}
+
 export class OutputParserException extends Error {
   output?: string;
 
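Any subclass now gets transform streaming for free: `_transform()` routes each streamed chunk through `parseResult()`, which by default delegates to `parse()`. A hedged sketch of a hypothetical custom parser built on the new base class:

```ts
// Hypothetical example; assumes the default parseResult() -> parse() delegation.
class UppercaseOutputParser extends BaseTransformOutputParser<string> {
  lc_namespace = ["schema", "output_parser"];

  // Called once per streamed chunk when used in a streaming pipeline.
  parse(text: string): Promise<string> {
    return Promise.resolve(text.toUpperCase());
  }

  getFormatInstructions(): string {
    return "";
  }
}

// e.g. llm.pipe(new UppercaseOutputParser()).stream("hi") yields "H", "I", ...
```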
71 changes: 68 additions & 3 deletions langchain/src/schema/runnable.ts
@@ -143,6 +143,50 @@ export abstract class Runnable<
     return output;
   }
 
+  protected async *_streamWithConfig<T extends RunOutput>(
+    generator: AsyncGenerator<T>,
+    options?: RunnableConfig & { runType?: string }
+  ) {
+    const callbackManager_ = await CallbackManager.configure(
+      options?.callbacks,
+      undefined,
+      options?.tags,
+      undefined,
+      options?.metadata
+    );
+    // TODO: Find a way to pass the entire streamed value into the callback.
+    const runManager = await callbackManager_?.handleChainStart(
+      this.toJSON(),
+      _coerceToDict("<streamed value>", "input"),
+      undefined,
+      options?.runType
+    );
+    let output;
+    let concatSupported = true;
+    try {
+      for await (const chunk of generator) {
+        yield chunk;
+        if (concatSupported) {
+          if (output === undefined) {
+            output = chunk;
+          } else {
+            try {
+              // eslint-disable-next-line @typescript-eslint/no-explicit-any
+              output = (output as any).concat(chunk);
+            } catch (e) {
+              output = undefined;
+              concatSupported = false;
+            }
+          }
+        }
+      }
+    } catch (e) {
+      await runManager?.handleChainError(e);
+      throw e;
+    }
+    await runManager?.handleChainEnd(_coerceToDict(output, "output"));
+  }
+
   _patchConfig(
     config: Partial<CallOptions> = {},
     callbackManager: CallbackManager | undefined = undefined
@@ -160,6 +204,11 @@ export abstract class Runnable<
     });
   }
 
+  transform?(
+    generator: AsyncGenerator<RunInput>,
+    options: Partial<CallOptions>
+  ): AsyncGenerator<RunOutput>;
+
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   static isRunnable(thing: any): thing is Runnable {
     return thing.lc_runnable;
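`_streamWithConfig` re-yields chunks untouched while opportunistically aggregating them for the end-of-run callback, and gives up on aggregation the first time a chunk type lacks a usable `.concat()`. A standalone trace of that fallback, outside the generator machinery:

```ts
// String chunks concatenate cleanly; a chunk type without .concat() would
// throw here, permanently disabling aggregation (callbacks then see undefined).
let output: string | undefined;
let concatSupported = true;
for (const chunk of ["Hi ", "there", "!"]) {
  if (!concatSupported) continue;
  if (output === undefined) {
    output = chunk;
  } else {
    try {
      output = output.concat(chunk);
    } catch {
      output = undefined;
      concatSupported = false;
    }
  }
}
// output === "Hi there!"
```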
@@ -314,8 +363,17 @@ export class RunnableSequence<
       _coerceToDict(input, "input")
     );
     let nextStepInput = input;
+    const steps = [this.first, ...this.middle, this.last];
+    // Find the index of the last runnable in the sequence that doesn't have a .transform() method
+    // and start streaming from there
+    const streamingStartStepIndex =
+      steps.length -
+      [...steps]
+        .reverse()
+        .findIndex((step) => typeof step.transform !== "function") -
+      1;
     try {
-      for (const step of [this.first, ...this.middle]) {
+      for (const step of steps.slice(0, streamingStartStepIndex)) {
         nextStepInput = await step.invoke(
           nextStepInput,
           this._patchConfig(options, runManager?.getChild())
@@ -328,11 +386,18 @@ export class RunnableSequence<
     let concatSupported = true;
     let finalOutput;
     try {
-      const iterator = await this.last._streamIterator(
+      let finalGenerator = await steps[streamingStartStepIndex]._streamIterator(
         nextStepInput,
         this._patchConfig(options, runManager?.getChild())
       );
-      for await (const chunk of iterator) {
+      for (const step of steps.slice(streamingStartStepIndex + 1)) {
+        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+        finalGenerator = await step.transform!(
+          finalGenerator,
+          this._patchConfig(options, runManager?.getChild())
+        );
+      }
+      for await (const chunk of finalGenerator) {
         yield chunk;
         if (concatSupported) {
           if (finalOutput === undefined) {
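To sanity-check the index arithmetic: in a hypothetical three-step sequence where only the final step implements `.transform()`, the search lands on the middle step — everything before it is invoked eagerly, the middle step is streamed via `_streamIterator`, and later steps transform that stream.

```ts
// Hypothetical steps: prompt -> llm -> parser; only the parser can transform.
const steps = [
  { transform: undefined },
  { transform: undefined },
  { transform: async function* () {} },
];
const streamingStartStepIndex =
  steps.length -
  [...steps]
    .reverse()
    .findIndex((step) => typeof step.transform !== "function") -
  1;
console.log(streamingStartStepIndex); // 1
```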
35 changes: 35 additions & 0 deletions langchain/src/schema/tests/output_parser.test.ts
@@ -0,0 +1,35 @@
+/* eslint-disable no-promise-executor-return */
+
+import { test } from "@jest/globals";
+import { LLM } from "../../llms/base.js";
+import { GenerationChunk } from "../index.js";
+import { BytesOutputParser } from "../output_parser.js";
+
+class FakeStreamingLLM extends LLM {
+  _llmType() {
+    return "fake";
+  }
+
+  async _call(prompt: string): Promise<string> {
+    return prompt;
+  }
+
+  async *_streamResponseChunks(input: string) {
+    for (const c of input) {
+      await new Promise((resolve) => setTimeout(resolve, 50));
+      yield { text: c, generationInfo: {} } as GenerationChunk;
+    }
+  }
+}
+
+test("BytesOutputParser", async () => {
+  const llm = new FakeStreamingLLM({});
+  const stream = await llm.pipe(new BytesOutputParser()).stream("Hi there!");
+  const chunks = [];
+  const decoder = new TextDecoder();
+  for await (const chunk of stream) {
+    chunks.push(decoder.decode(chunk));
+  }
+  expect(chunks.length).toEqual("Hi there!".length);
+  expect(chunks.join("")).toEqual("Hi there!");
+});
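One caveat this test sidesteps (every chunk here is single-byte ASCII): if a byte chunk could end partway through a multi-byte UTF-8 character, decode incrementally so `TextDecoder` buffers the partial sequence across chunks:

```ts
for await (const chunk of stream) {
  chunks.push(decoder.decode(chunk, { stream: true }));
}
```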
53 changes: 52 additions & 1 deletion langchain/src/schema/tests/runnable.test.ts
@@ -1,11 +1,18 @@
 /* eslint-disable no-promise-executor-return */
 
 import { z } from "zod";
 import { test } from "@jest/globals";
 import { LLM } from "../../llms/base.js";
 import {
   BaseChatModel,
   createChatMessageChunkEncoderStream,
 } from "../../chat_models/base.js";
-import { AIMessage, BaseMessage, ChatResult } from "../index.js";
+import {
+  AIMessage,
+  BaseMessage,
+  ChatResult,
+  GenerationChunk,
+} from "../index.js";
 import {
   ChatPromptTemplate,
   HumanMessagePromptTemplate,
@@ -28,6 +35,23 @@ class FakeLLM extends LLM {
   }
 }
 
+class FakeStreamingLLM extends LLM {
+  _llmType() {
+    return "fake";
+  }
+
+  async _call(prompt: string): Promise<string> {
+    return prompt;
+  }
+
+  async *_streamResponseChunks(input: string) {
+    for (const c of input) {
+      await new Promise((resolve) => setTimeout(resolve, 50));
+      yield { text: c, generationInfo: {} } as GenerationChunk;
+    }
+  }
+}
+
 class FakeChatModel extends BaseChatModel {
   _combineLLMOutput() {
     return [];
@@ -195,3 +219,30 @@ test("Bind kwargs to a runnable with a batch call", async () => {
   console.log(result);
   expect(result).toEqual(["testing", "testing", "testing", "testing"]);
 });
+
+test("Stream the entire way through", async () => {
+  const llm = new FakeStreamingLLM({});
+  const stream = await llm.pipe(new StringOutputParser()).stream("Hi there!");
+  const chunks = [];
+  for await (const chunk of stream) {
+    chunks.push(chunk);
+    console.log(chunk);
+  }
+  expect(chunks.length).toEqual("Hi there!".length);
+  expect(chunks.join("")).toEqual("Hi there!");
+});
+
+test("Don't use intermediate streaming", async () => {
+  const llm = new FakeStreamingLLM({});
+  const stream = await llm
+    .pipe(new StringOutputParser())
+    .pipe(new FakeLLM({}))
+    .stream("Hi there!");
+  const chunks = [];
+  for await (const chunk of stream) {
+    chunks.push(chunk);
+    console.log(chunk);
+  }
+  expect(chunks.length).toEqual(1);
+  expect(chunks[0]).toEqual("Hi there!");
+});