From 93862f4f9f7c5b36b5fa59142546a3c69c7f1ccb Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Fri, 24 Jan 2025 17:34:09 -0800 Subject: [PATCH] fix(core): Fix trim messages mutation (#7585) --- .../src/messages/tests/message_utils.test.ts | 36 +++++++++++++++++- langchain-core/src/messages/transformers.ts | 37 ++++++++++++++----- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/langchain-core/src/messages/tests/message_utils.test.ts b/langchain-core/src/messages/tests/message_utils.test.ts index a570ee478f60..1e01b4e7687d 100644 --- a/langchain-core/src/messages/tests/message_utils.test.ts +++ b/langchain-core/src/messages/tests/message_utils.test.ts @@ -1,4 +1,5 @@ import { it, describe, test, expect } from "@jest/globals"; +import { v4 } from "uuid"; import { filterMessages, mergeMessageRuns, @@ -198,6 +199,34 @@ describe("trimMessages can trim", () => { }; }; + it("should not mutate messages", async () => { + const messages: BaseMessage[] = [ + new HumanMessage({ + content: `My name is Jane Doe. + this is a long text + `, + id: v4(), + }), + new HumanMessage({ + content: `My name is Jane Doe.feiowfjoaejfioewaijof ewoif ioawej foiaew iofewi ao + this is a longer text than the first text. + `, + id: v4(), + }), + ]; + + const repr = JSON.stringify(messages); + + await trimMessages(messages, { + maxTokens: 14, + strategy: "last", + tokenCounter: () => 100, + allowPartial: true, + }); + + expect(repr).toEqual(JSON.stringify(messages)); + }); + it("should not mutate messages if no trimming occurs with strategy last", async () => { const trimmer = trimMessages({ maxTokens: 128000, @@ -211,6 +240,8 @@ describe("trimMessages can trim", () => { content: "Fetch the last 5 emails from Flora Testington's inbox.", additional_kwargs: {}, response_metadata: {}, + id: undefined, + name: undefined, }), new AIMessageChunk({ id: "chatcmpl-abcdefg", @@ -258,11 +289,11 @@ describe("trimMessages can trim", () => { name: "getEmails", args: '{"inboxName":"flora@foo.org","amount":5,"folder":"Inbox","searchString":null,"from":null,"subject":null,"cc":[],"bcc":[]}', id: "foobarbaz", - index: 0, type: "tool_call_chunk", }, ], invalid_tool_calls: [], + name: undefined, }), new ToolMessage({ content: "a whole bunch of emails!", @@ -270,6 +301,9 @@ describe("trimMessages can trim", () => { additional_kwargs: {}, response_metadata: {}, tool_call_id: "foobarbaz", + artifact: undefined, + id: undefined, + status: undefined, }), ]; const trimmedMessages = await trimmer.invoke(messages); diff --git a/langchain-core/src/messages/transformers.ts b/langchain-core/src/messages/transformers.ts index c96ecd69ce48..ffd2d862cd94 100644 --- a/langchain-core/src/messages/transformers.ts +++ b/langchain-core/src/messages/transformers.ts @@ -7,6 +7,7 @@ import { MessageType, BaseMessageChunk, BaseMessageFields, + isBaseMessageChunk, } from "./base.js"; import { ChatMessage, @@ -56,16 +57,16 @@ const _isMessageType = (msg: BaseMessage, types: MessageTypeOrClass[]) => { // eslint-disable-next-line @typescript-eslint/no-explicit-any const instantiatedMsgClass = new (t as any)({}); if ( - !("_getType" in instantiatedMsgClass) || - typeof instantiatedMsgClass._getType !== "function" + !("getType" in instantiatedMsgClass) || + typeof instantiatedMsgClass.getType !== "function" ) { throw new Error("Invalid type provided."); } - return instantiatedMsgClass._getType(); + return instantiatedMsgClass.getType(); }) ), ]; - const msgType = msg._getType(); + const msgType = msg.getType(); return typesAsStrings.some((t) => t === msgType); }; @@ -279,8 +280,8 @@ function _mergeMessageRuns(messages: BaseMessage[]): BaseMessage[] { if (!last) { merged.push(curr); } else if ( - curr._getType() === "tool" || - !(curr._getType() === last._getType()) + curr.getType() === "tool" || + !(curr.getType() === last.getType()) ) { merged.push(last, curr); } else { @@ -767,7 +768,7 @@ async function _firstMaxTokens( ([k]) => k !== "type" && !k.startsWith("lc_") ) ) as BaseMessageFields; - const updatedMessage = _switchTypeToMessage(excluded._getType(), { + const updatedMessage = _switchTypeToMessage(excluded.getType(), { ...fields, content: partialContent, }); @@ -862,7 +863,18 @@ async function _lastMaxTokens( } = options; // Create a copy of messages to avoid mutation - let messagesCopy = [...messages]; + let messagesCopy = messages.map((message) => { + const fields = Object.fromEntries( + Object.entries(message).filter( + ([k]) => k !== "type" && !k.startsWith("lc_") + ) + ) as BaseMessageFields; + return _switchTypeToMessage( + message.getType(), + fields, + isBaseMessageChunk(message) + ); + }); if (endOn) { const endOnArr = Array.isArray(endOn) ? endOn : [endOn]; @@ -875,7 +887,7 @@ async function _lastMaxTokens( } const swappedSystem = - includeSystem && messagesCopy[0]?._getType() === "system"; + includeSystem && messagesCopy[0]?.getType() === "system"; let reversed_ = swappedSystem ? messagesCopy.slice(0, 1).concat(messagesCopy.slice(1).reverse()) : messagesCopy.reverse(); @@ -943,6 +955,11 @@ function _switchTypeToMessage( fields: BaseMessageFields, returnChunk: true ): BaseMessageChunk; +function _switchTypeToMessage( + messageType: MessageType, + fields: BaseMessageFields, + returnChunk?: boolean +): BaseMessageChunk | BaseMessage; function _switchTypeToMessage( messageType: MessageType, fields: BaseMessageFields, @@ -1058,7 +1075,7 @@ function _switchTypeToMessage( } function _chunkToMsg(chunk: BaseMessageChunk): BaseMessage { - const chunkType = chunk._getType(); + const chunkType = chunk.getType(); let msg: BaseMessage | undefined; const fields = Object.fromEntries( Object.entries(chunk).filter(