
fix(core): Prevent cache misses from triggering model start callback runs twice #7565

Merged (5 commits) on Jan 21, 2025
Changes from all commits
langchain-core/src/language_models/chat_models.ts (120 changes: 80 additions & 40 deletions)
@@ -9,6 +9,8 @@ import {
coerceMessageLikeToMessage,
AIMessageChunk,
isAIMessageChunk,
isBaseMessage,
isAIMessage,
} from "../messages/index.js";
import type { BasePromptValueInterface } from "../prompt_values.js";
import {
@@ -343,41 +345,50 @@ export abstract class BaseChatModel<
async _generateUncached(
messages: BaseMessageLike[][],
parsedOptions: this["ParsedCallOptions"],
handledOptions: RunnableConfig
handledOptions: RunnableConfig,
startedRunManagers?: CallbackManagerForLLMRun[]
): Promise<LLMResult> {
const baseMessages = messages.map((messageList) =>
messageList.map(coerceMessageLikeToMessage)
);

const inheritableMetadata = {
...handledOptions.metadata,
...this.getLsParams(parsedOptions),
};
// create callback manager and start run
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
inheritableMetadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: 1,
};
const runManagers = await callbackManager_?.handleChatModelStart(
this.toJSON(),
baseMessages,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions.runName
);
let runManagers: CallbackManagerForLLMRun[] | undefined;
if (
startedRunManagers !== undefined &&
startedRunManagers.length === baseMessages.length
) {
runManagers = startedRunManagers;
} else {
const inheritableMetadata = {
...handledOptions.metadata,
...this.getLsParams(parsedOptions),
};
// create callback manager and start run
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
inheritableMetadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: 1,
};
runManagers = await callbackManager_?.handleChatModelStart(
this.toJSON(),
baseMessages,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions.runName
);
}
const generations: ChatGeneration[][] = [];
const llmOutputs: LLMResult["llmOutput"][] = [];
// Even if stream is not explicitly called, check if model is implicitly
@@ -511,7 +522,12 @@
// eslint-disable-next-line @typescript-eslint/no-explicit-any
parsedOptions: any;
handledOptions: RunnableConfig;
}): Promise<LLMResult & { missingPromptIndices: number[] }> {
}): Promise<
LLMResult & {
missingPromptIndices: number[];
startedRunManagers?: CallbackManagerForLLMRun[];
}
> {
const baseMessages = messages.map((messageList) =>
messageList.map(coerceMessageLikeToMessage)
);
@@ -580,7 +596,26 @@
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
generations[i] = result.map((result) => {
if (
"message" in result &&
isBaseMessage(result.message) &&
isAIMessage(result.message)
) {
// eslint-disable-next-line no-param-reassign
result.message.usage_metadata = {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
}
// eslint-disable-next-line no-param-reassign
result.generationInfo = {
...result.generationInfo,
tokenUsage: {},
};
return result;
});
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
@@ -598,6 +633,7 @@
const output = {
generations,
missingPromptIndices,
startedRunManagers: runManagers,
};

// This defines RUN_KEY as a non-enumerable property on the output object
@@ -650,20 +686,24 @@
callOptions as CallOptions
);

const { generations, missingPromptIndices } = await this._generateCached({
messages: baseMessages,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});
const { generations, missingPromptIndices, startedRunManagers } =
await this._generateCached({
messages: baseMessages,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
const results = await this._generateUncached(
missingPromptIndices.map((i) => baseMessages[i]),
callOptions,
runnableConfig
runnableConfig,
startedRunManagers !== undefined
? missingPromptIndices.map((i) => startedRunManagers?.[i])
: undefined
);
await Promise.all(
results.generations.map(async (generation, index) => {
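
The core of the change above, stripped of LangChain's types, is that the cache lookup path starts the runs and hands the resulting run managers to the uncached path, which reuses them instead of firing the start callback a second time. Below is a minimal, hypothetical sketch of that pattern in plain TypeScript; RunManager, startRun, generateUncached, and main are stand-ins invented for illustration, not LangChain's CallbackManagerForLLMRun API.

// Hypothetical, simplified model of the reuse pattern in this PR. startRun() stands
// in for callbackManager.handleChatModelStart and is NOT LangChain's real API.
interface RunManager {
  runId: string;
  end(output: string): Promise<void>;
}

let runCounter = 0;

async function startRun(prompt: string): Promise<RunManager> {
  const runId = `run-${runCounter++}`;
  console.log(`start callback fired for "${prompt}" (${runId})`);
  return {
    runId,
    end: async (output: string) => {
      console.log(`${runId} ended with: ${output}`);
    },
  };
}

// Mirrors the new _generateUncached signature: reuse run managers that were already
// started during the cache lookup instead of firing the start callback again.
async function generateUncached(
  prompts: string[],
  startedRunManagers?: RunManager[]
): Promise<void> {
  const runManagers =
    startedRunManagers !== undefined &&
    startedRunManagers.length === prompts.length
      ? startedRunManagers
      : await Promise.all(prompts.map(startRun));
  await Promise.all(
    runManagers.map((rm, i) => rm.end(`generation for "${prompts[i]}"`))
  );
}

// Mirrors the caller: the cache step starts one run per prompt, and only the run
// managers for prompts that missed the cache are passed to the uncached step.
async function main(): Promise<void> {
  const prompts = ["prompt a", "prompt b"];
  const startedRunManagers = await Promise.all(prompts.map(startRun));
  const missingPromptIndices = [1]; // pretend only "prompt b" missed the cache
  await generateUncached(
    missingPromptIndices.map((i) => prompts[i]),
    missingPromptIndices.map((i) => startedRunManagers[i])
  );
}

void main();
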
langchain-core/src/language_models/llms.ts (98 changes: 62 additions & 36 deletions)
@@ -240,32 +240,41 @@ export abstract class BaseLLM<
async _generateUncached(
prompts: string[],
parsedOptions: this["ParsedCallOptions"],
handledOptions: BaseCallbackConfig
handledOptions: BaseCallbackConfig,
startedRunManagers?: CallbackManagerForLLMRun[]
): Promise<LLMResult> {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: prompts.length,
};
const runManagers = await callbackManager_?.handleLLMStart(
this.toJSON(),
prompts,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions?.runName
);
let runManagers: CallbackManagerForLLMRun[] | undefined;
if (
startedRunManagers !== undefined &&
startedRunManagers.length === prompts.length
) {
runManagers = startedRunManagers;
} else {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
handledOptions.tags,
this.tags,
handledOptions.metadata,
this.metadata,
{ verbose: this.verbose }
);
const extra = {
options: parsedOptions,
invocation_params: this?.invocationParams(parsedOptions),
batch_size: prompts.length,
};
runManagers = await callbackManager_?.handleLLMStart(
this.toJSON(),
prompts,
handledOptions.runId,
undefined,
extra,
undefined,
undefined,
handledOptions?.runName
);
}
// Even if stream is not explicitly called, check if model is implicitly
// called from streamEvents() or streamLog() to get all streamed events.
// Bail out if _streamResponseChunks not overridden
@@ -346,7 +355,12 @@
parsedOptions: any;
handledOptions: RunnableConfig;
runId?: string;
}): Promise<LLMResult & { missingPromptIndices: number[] }> {
}): Promise<
LLMResult & {
missingPromptIndices: number[];
startedRunManagers?: CallbackManagerForLLMRun[];
}
> {
const callbackManager_ = await CallbackManager.configure(
handledOptions.callbacks,
this.callbacks,
@@ -401,7 +415,14 @@
cachedResults.map(async ({ result: promiseResult, runManager }, i) => {
if (promiseResult.status === "fulfilled") {
const result = promiseResult.value as Generation[];
generations[i] = result;
generations[i] = result.map((result) => {
// eslint-disable-next-line no-param-reassign
result.generationInfo = {
...result.generationInfo,
tokenUsage: {},
};
return result;
});
if (result.length) {
await runManager?.handleLLMNewToken(result[0].text);
}
@@ -419,6 +440,7 @@
const output = {
generations,
missingPromptIndices,
startedRunManagers: runManagers,
};

// This defines RUN_KEY as a non-enumerable property on the output object
@@ -465,21 +487,25 @@
const llmStringKey = this._getSerializedCacheKeyParametersForCall(
callOptions as CallOptions
);
const { generations, missingPromptIndices } = await this._generateCached({
prompts,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
runId: runnableConfig.runId,
});
const { generations, missingPromptIndices, startedRunManagers } =
await this._generateCached({
prompts,
cache,
llmStringKey,
parsedOptions: callOptions,
handledOptions: runnableConfig,
runId: runnableConfig.runId,
});

let llmOutput = {};
if (missingPromptIndices.length > 0) {
const results = await this._generateUncached(
missingPromptIndices.map((i) => prompts[i]),
callOptions,
runnableConfig
runnableConfig,
startedRunManagers !== undefined
? missingPromptIndices.map((i) => startedRunManagers?.[i])
: undefined
);
await Promise.all(
results.generations.map(async (generation, index) => {
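
Alongside the run-manager reuse, both files now clear token usage on cached generations (usage_metadata on cached AI messages in chat_models.ts, and tokenUsage in generationInfo in both files), since a cache hit consumes no tokens. Below is a rough sketch of that normalization using a simplified Generation shape rather than LangChain's real types; the diff mutates the cached objects in place, while this sketch copies them, which has the same effect on the returned result.

// Simplified stand-in for the Generation type used above; not the real interface.
interface Generation {
  text: string;
  generationInfo?: Record<string, unknown>;
}

// Cached results are mapped so tokenUsage is emptied: the tokens were counted when
// the result was originally generated, not on this cache hit.
function normalizeCachedGenerations(cached: Generation[]): Generation[] {
  return cached.map((generation) => ({
    ...generation,
    generationInfo: { ...generation.generationInfo, tokenUsage: {} },
  }));
}

const fromCache: Generation[] = [
  { text: "hello", generationInfo: { tokenUsage: { totalTokens: 42 } } },
];
console.log(normalizeCachedGenerations(fromCache));
// -> [ { text: 'hello', generationInfo: { tokenUsage: {} } } ]
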
langchain-core/src/language_models/tests/chat_models.test.ts (48 changes: 48 additions & 0 deletions)
@@ -287,6 +287,54 @@ test("Test ChatModel can cache complex messages", async () => {
expect(cachedMsg.content).toEqual(JSON.stringify(contentToCache, null, 2));
});

test("Test ChatModel with cache does not start multiple chat model runs", async () => {
const model = new FakeChatModel({
cache: true,
});
if (!model.cache) {
throw new Error("Cache not enabled");
}

const contentToCache = [
{
type: "text",
text: "Hello there again!",
},
];
const humanMessage = new HumanMessage({
content: contentToCache,
});

const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({});

const value = await model.cache.lookup(prompt, llmKey);
expect(value).toBeNull();

// Invoke model to trigger cache update
const eventStream = model.streamEvents([humanMessage], { version: "v2" });

expect(await model.cache.lookup(prompt, llmKey)).toBeDefined();

const events = [];
for await (const event of eventStream) {
events.push(event);
}
expect(events.length).toEqual(2);
expect(events[0].event).toEqual("on_chat_model_start");
expect(events[1].event).toEqual("on_chat_model_end");

const eventStream2 = model.streamEvents([humanMessage], { version: "v2" });

const events2 = [];
for await (const event of eventStream2) {
events2.push(event);
}
expect(events2.length).toEqual(2);
expect(events2[0].event).toEqual("on_chat_model_start");
expect(events2[1].event).toEqual("on_chat_model_end");
});

test("Test ChatModel can emit a custom event", async () => {
const model = new FakeListChatModel({
responses: ["hi"],
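
From user code, the regression showed up as a duplicated on_chat_model_start event when a prompt missed the cache. The helper below is illustrative only: it counts start events from streamEvents for a single call, assuming a model constructed with cache: true; the structural type is an assumption, and only the v2 event names come from the test above.

// Illustrative only: counts on_chat_model_start events for a single call. The
// minimal structural type below is an assumption, not LangChain's full interface.
type StreamEventsModel = {
  streamEvents(
    input: unknown,
    options: { version: "v2" }
  ): AsyncIterable<{ event: string }>;
};

async function countStartEvents(
  model: StreamEventsModel,
  input: unknown
): Promise<number> {
  let starts = 0;
  for await (const { event } of model.streamEvents(input, { version: "v2" })) {
    if (event === "on_chat_model_start") starts += 1;
  }
  return starts;
}

// Before this fix, a call whose prompt missed the cache could report 2 here (one
// start from the cache lookup and one from the uncached generation); after it,
// every call reports exactly 1.
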