diff --git a/langchain-core/src/language_models/base.ts b/langchain-core/src/language_models/base.ts index 8adea3e83b4e..0e8af1bc32bf 100644 --- a/langchain-core/src/language_models/base.ts +++ b/langchain-core/src/language_models/base.ts @@ -207,18 +207,6 @@ export interface BaseLanguageModelCallOptions extends RunnableConfig { * If not provided, the default stop tokens for the model will be used. */ stop?: string[]; - - /** - * Timeout for this call in milliseconds. - */ - timeout?: number; - - /** - * Abort signal for this call. - * If provided, the call will be aborted when the signal is aborted. - * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal - */ - signal?: AbortSignal; } export interface FunctionDefinition { diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts index 30256073c33a..b60baecb3a20 100644 --- a/langchain-core/src/language_models/chat_models.ts +++ b/langchain-core/src/language_models/chat_models.ts @@ -145,7 +145,7 @@ export abstract class BaseChatModel< > extends BaseLanguageModel { declare ParsedCallOptions: Omit< CallOptions, - keyof RunnableConfig & "timeout" + Exclude >; // Only ever instantiated in main LangChain @@ -159,14 +159,13 @@ export abstract class BaseChatModel< ...llmOutputs: LLMResult["llmOutput"][] ): LLMResult["llmOutput"]; - protected _separateRunnableConfigFromCallOptions( + protected _separateRunnableConfigFromCallOptionsCompat( options?: Partial ): [RunnableConfig, this["ParsedCallOptions"]] { + // For backwards compat, keep `signal` in both runnableConfig and callOptions const [runnableConfig, callOptions] = super._separateRunnableConfigFromCallOptions(options); - if (callOptions?.timeout && !callOptions.signal) { - callOptions.signal = AbortSignal.timeout(callOptions.timeout); - } + (callOptions as this["ParsedCallOptions"]).signal = runnableConfig.signal; return [runnableConfig, callOptions as this["ParsedCallOptions"]]; } @@ -232,7 +231,7 @@ export abstract class BaseChatModel< const prompt = BaseChatModel._convertInputToPromptValue(input); const messages = prompt.toChatMessages(); const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(options); + this._separateRunnableConfigFromCallOptionsCompat(options); const inheritableMetadata = { ...runnableConfig.metadata, @@ -578,7 +577,7 @@ export abstract class BaseChatModel< ); const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(parsedOptions); + this._separateRunnableConfigFromCallOptionsCompat(parsedOptions); runnableConfig.callbacks = runnableConfig.callbacks ?? callbacks; if (!this.cache) { @@ -586,8 +585,9 @@ export abstract class BaseChatModel< } const { cache } = this; - const llmStringKey = - this._getSerializedCacheKeyParametersForCall(callOptions); + const llmStringKey = this._getSerializedCacheKeyParametersForCall( + callOptions as CallOptions + ); const { generations, missingPromptIndices } = await this._generateCached({ messages: baseMessages, diff --git a/langchain-core/src/language_models/llms.ts b/langchain-core/src/language_models/llms.ts index 20b0e812deb7..f6df3677dfc1 100644 --- a/langchain-core/src/language_models/llms.ts +++ b/langchain-core/src/language_models/llms.ts @@ -63,7 +63,7 @@ export abstract class BaseLLM< > extends BaseLanguageModel { declare ParsedCallOptions: Omit< CallOptions, - keyof RunnableConfig & "timeout" + Exclude >; // Only ever instantiated in main LangChain @@ -103,14 +103,13 @@ export abstract class BaseLLM< throw new Error("Not implemented."); } - protected _separateRunnableConfigFromCallOptions( + protected _separateRunnableConfigFromCallOptionsCompat( options?: Partial ): [RunnableConfig, this["ParsedCallOptions"]] { + // For backwards compat, keep `signal` in both runnableConfig and callOptions const [runnableConfig, callOptions] = super._separateRunnableConfigFromCallOptions(options); - if (callOptions?.timeout && !callOptions.signal) { - callOptions.signal = AbortSignal.timeout(callOptions.timeout); - } + (callOptions as this["ParsedCallOptions"]).signal = runnableConfig.signal; return [runnableConfig, callOptions as this["ParsedCallOptions"]]; } @@ -126,7 +125,7 @@ export abstract class BaseLLM< } else { const prompt = BaseLLM._convertInputToPromptValue(input); const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(options); + this._separateRunnableConfigFromCallOptionsCompat(options); const callbackManager_ = await CallbackManager.configure( runnableConfig.callbacks, this.callbacks, @@ -461,7 +460,7 @@ export abstract class BaseLLM< } const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(parsedOptions); + this._separateRunnableConfigFromCallOptionsCompat(parsedOptions); runnableConfig.callbacks = runnableConfig.callbacks ?? callbacks; if (!this.cache) { @@ -469,8 +468,9 @@ export abstract class BaseLLM< } const { cache } = this; - const llmStringKey = - this._getSerializedCacheKeyParametersForCall(callOptions); + const llmStringKey = this._getSerializedCacheKeyParametersForCall( + callOptions as CallOptions + ); const { generations, missingPromptIndices } = await this._generateCached({ prompts, cache, diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index 4cb65f379f22..ba19752c8f10 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -339,6 +339,8 @@ export abstract class Runnable< recursionLimit: options.recursionLimit, maxConcurrency: options.maxConcurrency, runId: options.runId, + timeout: options.timeout, + signal: options.signal, }); } const callOptions = { ...(options as Partial) }; @@ -350,6 +352,8 @@ export abstract class Runnable< delete callOptions.recursionLimit; delete callOptions.maxConcurrency; delete callOptions.runId; + delete callOptions.timeout; + delete callOptions.signal; return [runnableConfig, callOptions]; } @@ -378,7 +382,17 @@ export abstract class Runnable< delete config.runId; let output; try { - output = await func.call(this, input, config, runManager); + const promise = func.call(this, input, config, runManager); + output = options?.signal + ? await Promise.race([ + promise, + new Promise((_, reject) => { + options.signal?.addEventListener("abort", () => { + reject(new Error("AbortError")); + }); + }), + ]) + : await promise; } catch (e) { await runManager?.handleChainError(e); throw e; @@ -430,13 +444,23 @@ export abstract class Runnable< ); let outputs: (RunOutput | Error)[]; try { - outputs = await func.call( + const promise = func.call( this, inputs, optionsList, runManagers, batchOptions ); + outputs = optionsList?.[0]?.signal + ? await Promise.race([ + promise, + new Promise((_, reject) => { + optionsList?.[0]?.signal?.addEventListener("abort", () => { + reject(new Error("AbortError")); + }); + }), + ]) + : await promise; } catch (e) { await Promise.all( runManagers.map((runManager) => runManager?.handleChainError(e)) @@ -509,6 +533,7 @@ export abstract class Runnable< undefined, config.runName ?? this.getName() ), + options?.signal, config ); delete config.runId; @@ -1750,14 +1775,27 @@ export class RunnableSequence< const initialSteps = [this.first, ...this.middle]; for (let i = 0; i < initialSteps.length; i += 1) { const step = initialSteps[i]; - nextStepInput = await step.invoke( + const promise = step.invoke( nextStepInput, patchConfig(config, { callbacks: runManager?.getChild(`seq:step:${i + 1}`), }) ); + nextStepInput = options?.signal + ? await Promise.race([ + promise, + new Promise((_, reject) => { + options.signal?.addEventListener("abort", () => + reject(new Error("Aborted")) + ); + }), + ]) + : await promise; } // TypeScript can't detect that the last output of the sequence returns RunOutput, so call it out of the loop here + if (options?.signal?.aborted) { + throw new Error("Aborted"); + } finalOutput = await this.last.invoke( nextStepInput, patchConfig(config, { @@ -1819,7 +1857,7 @@ export class RunnableSequence< try { for (let i = 0; i < this.steps.length; i += 1) { const step = this.steps[i]; - nextStepInputs = await step.batch( + const promise = step.batch( nextStepInputs, runManagers.map((runManager, j) => { const childRunManager = runManager?.getChild(`seq:step:${i + 1}`); @@ -1827,6 +1865,16 @@ export class RunnableSequence< }), batchOptions ); + nextStepInputs = configList[0]?.signal + ? await Promise.race([ + promise, + new Promise((_, reject) => { + configList[0]?.signal?.addEventListener("abort", () => + reject(new Error("Aborted")) + ); + }), + ]) + : await promise; } } catch (e) { await Promise.all( @@ -1880,6 +1928,7 @@ export class RunnableSequence< ); } for await (const chunk of finalGenerator) { + options?.signal?.throwIfAborted(); yield chunk; if (concatSupported) { if (finalOutput === undefined) { @@ -2058,16 +2107,26 @@ export class RunnableMap< // eslint-disable-next-line @typescript-eslint/no-explicit-any const output: Record = {}; try { - await Promise.all( - Object.entries(this.steps).map(async ([key, runnable]) => { + const promises = Object.entries(this.steps).map( + async ([key, runnable]) => { output[key] = await runnable.invoke( input, patchConfig(config, { callbacks: runManager?.getChild(`map:key:${key}`), }) ); - }) + } ); + if (options?.signal) { + promises.push( + new Promise((_, reject) => { + options.signal?.addEventListener("abort", () => + reject(new Error("Aborted")) + ); + }) + ); + } + await Promise.all(promises); } catch (e) { await runManager?.handleChainError(e); throw e; @@ -2101,7 +2160,17 @@ export class RunnableMap< // starting new iterations as needed, // until all iterators are done while (tasks.size) { - const { key, result, gen } = await Promise.race(tasks.values()); + const promise = Promise.race(tasks.values()); + const { key, result, gen } = options?.signal + ? await Promise.race([ + promise, + new Promise((_, reject) => { + options.signal?.addEventListener("abort", () => + reject(new Error("Aborted")) + ); + }), + ]) + : await promise; tasks.delete(key); if (!result.done) { yield { [key]: result.value } as unknown as RunOutput; @@ -2172,21 +2241,33 @@ export class RunnableTraceable extends Runnable< async invoke(input: RunInput, options?: Partial) { const [config] = this._getOptionsList(options ?? {}, 1); const callbacks = await getCallbackManagerForConfig(config); - - return (await this.func( + const promise = this.func( patchConfig(config, { callbacks }), input - )) as RunOutput; + ) as Promise; + + return config?.signal + ? Promise.race([ + promise, + new Promise((_, reject) => { + config.signal?.addEventListener("abort", () => + reject(new Error("Aborted")) + ); + }), + ]) + : await promise; } async *_streamIterator( input: RunInput, options?: Partial ): AsyncGenerator { + const [config] = this._getOptionsList(options ?? {}, 1); const result = await this.invoke(input, options); if (isAsyncIterable(result)) { for await (const item of result) { + config?.signal?.throwIfAborted(); yield item as RunOutput; } return; @@ -2194,6 +2275,7 @@ export class RunnableTraceable extends Runnable< if (isIterator(result)) { while (true) { + config?.signal?.throwIfAborted(); const state: IteratorResult = result.next(); if (state.done) break; yield state.value as RunOutput; @@ -2320,6 +2402,7 @@ export class RunnableLambda extends Runnable< childConfig, output )) { + config?.signal?.throwIfAborted(); if (finalOutput === undefined) { finalOutput = chunk as RunOutput; } else { @@ -2339,6 +2422,7 @@ export class RunnableLambda extends Runnable< childConfig, output )) { + config?.signal?.throwIfAborted(); if (finalOutput === undefined) { finalOutput = chunk as RunOutput; } else { @@ -2423,10 +2507,12 @@ export class RunnableLambda extends Runnable< childConfig, output )) { + config?.signal?.throwIfAborted(); yield chunk as RunOutput; } } else if (isIterableIterator(output)) { for (const chunk of consumeIteratorInContext(childConfig, output)) { + config?.signal?.throwIfAborted(); yield chunk as RunOutput; } } else { @@ -2517,6 +2603,7 @@ export class RunnableWithFallbacks extends Runnable< ); let firstError; for (const runnable of this.runnables()) { + config?.signal?.throwIfAborted(); try { const output = await runnable.invoke( input, @@ -2586,6 +2673,7 @@ export class RunnableWithFallbacks extends Runnable< // eslint-disable-next-line @typescript-eslint/no-explicit-any let firstError: any; for (const runnable of this.runnables()) { + configList[0].signal?.throwIfAborted(); try { const outputs = await runnable.batch( inputs, diff --git a/langchain-core/src/runnables/config.ts b/langchain-core/src/runnables/config.ts index 409d556eac8d..04687ee93a1d 100644 --- a/langchain-core/src/runnables/config.ts +++ b/langchain-core/src/runnables/config.ts @@ -31,6 +31,18 @@ export function mergeConfigs( copy[key] = [...new Set(baseKeys.concat(options[key] ?? []))]; } else if (key === "configurable") { copy[key] = { ...copy[key], ...options[key] }; + } else if (key === "timeout") { + if (copy.timeout === undefined) { + copy.timeout = options.timeout; + } else if (options.timeout !== undefined) { + copy.timeout = Math.min(copy.timeout, options.timeout); + } + } else if (key === "signal") { + if (copy.signal === undefined) { + copy.signal = options.signal; + } else if (options.signal !== undefined) { + copy.signal = AbortSignal.any([copy.signal, options.signal]); + } } else if (key === "callbacks") { const baseCallbacks = copy.callbacks; const providedCallbacks = options.callbacks; @@ -155,6 +167,18 @@ export function ensureConfig( } } } + if (empty.timeout !== undefined) { + if (empty.timeout <= 0) { + throw new Error("Timeout must be a positive number"); + } + const timeoutSignal = AbortSignal.timeout(empty.timeout); + if (empty.signal !== undefined) { + empty.signal = AbortSignal.any([empty.signal, timeoutSignal]); + } else { + empty.signal = timeoutSignal; + } + delete empty.timeout; + } return empty as CallOptions; } diff --git a/langchain-core/src/runnables/remote.ts b/langchain-core/src/runnables/remote.ts index 9ecd597556f9..dc08c731a501 100644 --- a/langchain-core/src/runnables/remote.ts +++ b/langchain-core/src/runnables/remote.ts @@ -214,11 +214,12 @@ function deserialize(str: string): RunOutput { return revive(obj); } -function removeCallbacks( +function removeCallbacksAndSignal( options?: RunnableConfig -): Omit { +): Omit { const rest = { ...options }; delete rest.callbacks; + delete rest.signal; return rest; } @@ -276,7 +277,7 @@ export class RemoteRunnable< this.options = options; } - private async post(path: string, body: Body) { + private async post(path: string, body: Body, signal?: AbortSignal) { return fetch(`${this.url}${path}`, { method: "POST", body: JSON.stringify(serialize(body)), @@ -284,7 +285,7 @@ export class RemoteRunnable< "Content-Type": "application/json", ...this.options?.headers, }, - signal: AbortSignal.timeout(this.options?.timeout ?? 60000), + signal: signal ?? AbortSignal.timeout(this.options?.timeout ?? 60000), }); } @@ -299,11 +300,15 @@ export class RemoteRunnable< input: RunInput; config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; - }>("/invoke", { - input, - config: removeCallbacks(config), - kwargs: kwargs ?? {}, - }); + }>( + "/invoke", + { + input, + config: removeCallbacksAndSignal(config), + kwargs: kwargs ?? {}, + }, + config.signal + ); if (!response.ok) { throw new Error(`${response.status} Error: ${await response.text()}`); } @@ -347,13 +352,17 @@ export class RemoteRunnable< inputs: RunInput[]; config?: (RunnableConfig & RunnableBatchOptions)[]; kwargs?: Omit, keyof RunnableConfig>[]; - }>("/batch", { - inputs, - config: (configs ?? []) - .map(removeCallbacks) - .map((config) => ({ ...config, ...batchOptions })), - kwargs, - }); + }>( + "/batch", + { + inputs, + config: (configs ?? []) + .map(removeCallbacksAndSignal) + .map((config) => ({ ...config, ...batchOptions })), + kwargs, + }, + options?.[0]?.signal + ); if (!response.ok) { throw new Error(`${response.status} Error: ${await response.text()}`); } @@ -422,11 +431,15 @@ export class RemoteRunnable< input: RunInput; config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; - }>("/stream", { - input, - config: removeCallbacks(config), - kwargs, - }); + }>( + "/stream", + { + input, + config: removeCallbacksAndSignal(config), + kwargs, + }, + config.signal + ); if (!response.ok) { const json = await response.json(); const error = new Error( @@ -502,13 +515,17 @@ export class RemoteRunnable< config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; diff: false; - }>("/stream_log", { - input, - config: removeCallbacks(config), - kwargs, - ...camelCaseStreamOptions, - diff: false, - }); + }>( + "/stream_log", + { + input, + config: removeCallbacksAndSignal(config), + kwargs, + ...camelCaseStreamOptions, + diff: false, + }, + config.signal + ); const { body, ok } = response; if (!ok) { throw new Error(`${response.status} Error: ${await response.text()}`); @@ -574,13 +591,17 @@ export class RemoteRunnable< config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; diff: false; - }>("/stream_events", { - input, - config: removeCallbacks(config), - kwargs, - ...camelCaseStreamOptions, - diff: false, - }); + }>( + "/stream_events", + { + input, + config: removeCallbacksAndSignal(config), + kwargs, + ...camelCaseStreamOptions, + diff: false, + }, + config.signal + ); const { body, ok } = response; if (!ok) { throw new Error(`${response.status} Error: ${await response.text()}`); diff --git a/langchain-core/src/runnables/types.ts b/langchain-core/src/runnables/types.ts index 569e8aa26c0e..e7ddfa8c3852 100644 --- a/langchain-core/src/runnables/types.ts +++ b/langchain-core/src/runnables/types.ts @@ -89,4 +89,16 @@ export interface RunnableConfig extends BaseCallbackConfig { /** Maximum number of parallel calls to make. */ maxConcurrency?: number; + + /** + * Timeout for this call in milliseconds. + */ + timeout?: number; + + /** + * Abort signal for this call. + * If provided, the call will be aborted when the signal is aborted. + * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal + */ + signal?: AbortSignal; } diff --git a/langchain-core/src/utils/stream.ts b/langchain-core/src/utils/stream.ts index 234cec3b900f..31ac8907b1c4 100644 --- a/langchain-core/src/utils/stream.ts +++ b/langchain-core/src/utils/stream.ts @@ -186,6 +186,8 @@ export class AsyncGeneratorWithSetup< public config?: unknown; + public signal?: AbortSignal; + private firstResult: Promise>; private firstResultUsed = false; @@ -194,9 +196,11 @@ export class AsyncGeneratorWithSetup< generator: AsyncGenerator; startSetup?: () => Promise; config?: unknown; + signal?: AbortSignal; }) { this.generator = params.generator; this.config = params.config; + this.signal = params.signal; // setup is a promise that resolves only after the first iterator value // is available. this is useful when setup of several piped generators // needs to happen in logical order, ie. in the order in which input to @@ -218,6 +222,8 @@ export class AsyncGeneratorWithSetup< } async next(...args: [] | [TNext]): Promise> { + this.signal?.throwIfAborted(); + if (!this.firstResultUsed) { this.firstResultUsed = true; return this.firstResult; @@ -225,9 +231,20 @@ export class AsyncGeneratorWithSetup< return AsyncLocalStorageProviderSingleton.runWithConfig( this.config, - async () => { - return this.generator.next(...args); - }, + this.signal + ? async () => { + return Promise.race([ + this.generator.next(...args), + new Promise((_resolve, reject) => { + this.signal?.addEventListener("abort", () => { + reject(new Error("Aborted")); + }); + }), + ]); + } + : async () => { + return this.generator.next(...args); + }, true ); } @@ -264,11 +281,13 @@ export async function pipeGeneratorWithSetup< ) => AsyncGenerator, generator: AsyncGenerator, startSetup: () => Promise, + signal: AbortSignal | undefined, ...args: A ) { const gen = new AsyncGeneratorWithSetup({ generator, startSetup, + signal, }); const setup = await gen.setup; return { output: to(gen, setup, ...args), setup };