Skip to content

Commit

Permalink
fix(core): Fix trim messages mutation (#7585)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Jan 25, 2025
1 parent a71067d commit 93862f4
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 11 deletions.
36 changes: 35 additions & 1 deletion langchain-core/src/messages/tests/message_utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { it, describe, test, expect } from "@jest/globals";
import { v4 } from "uuid";
import {
filterMessages,
mergeMessageRuns,
Expand Down Expand Up @@ -198,6 +199,34 @@ describe("trimMessages can trim", () => {
};
};

it("should not mutate messages", async () => {
const messages: BaseMessage[] = [
new HumanMessage({
content: `My name is Jane Doe.
this is a long text
`,
id: v4(),
}),
new HumanMessage({
content: `My name is Jane Doe.feiowfjoaejfioewaijof ewoif ioawej foiaew iofewi ao
this is a longer text than the first text.
`,
id: v4(),
}),
];

const repr = JSON.stringify(messages);

await trimMessages(messages, {
maxTokens: 14,
strategy: "last",
tokenCounter: () => 100,
allowPartial: true,
});

expect(repr).toEqual(JSON.stringify(messages));
});

it("should not mutate messages if no trimming occurs with strategy last", async () => {
const trimmer = trimMessages({
maxTokens: 128000,
Expand All @@ -211,6 +240,8 @@ describe("trimMessages can trim", () => {
content: "Fetch the last 5 emails from Flora Testington's inbox.",
additional_kwargs: {},
response_metadata: {},
id: undefined,
name: undefined,
}),
new AIMessageChunk({
id: "chatcmpl-abcdefg",
Expand Down Expand Up @@ -258,18 +289,21 @@ describe("trimMessages can trim", () => {
name: "getEmails",
args: '{"inboxName":"flora@foo.org","amount":5,"folder":"Inbox","searchString":null,"from":null,"subject":null,"cc":[],"bcc":[]}',
id: "foobarbaz",
index: 0,
type: "tool_call_chunk",
},
],
invalid_tool_calls: [],
name: undefined,
}),
new ToolMessage({
content: "a whole bunch of emails!",
name: "getEmails",
additional_kwargs: {},
response_metadata: {},
tool_call_id: "foobarbaz",
artifact: undefined,
id: undefined,
status: undefined,
}),
];
const trimmedMessages = await trimmer.invoke(messages);
Expand Down
37 changes: 27 additions & 10 deletions langchain-core/src/messages/transformers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
MessageType,
BaseMessageChunk,
BaseMessageFields,
isBaseMessageChunk,
} from "./base.js";
import {
ChatMessage,
Expand Down Expand Up @@ -56,16 +57,16 @@ const _isMessageType = (msg: BaseMessage, types: MessageTypeOrClass[]) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const instantiatedMsgClass = new (t as any)({});
if (
!("_getType" in instantiatedMsgClass) ||
typeof instantiatedMsgClass._getType !== "function"
!("getType" in instantiatedMsgClass) ||
typeof instantiatedMsgClass.getType !== "function"
) {
throw new Error("Invalid type provided.");
}
return instantiatedMsgClass._getType();
return instantiatedMsgClass.getType();
})
),
];
const msgType = msg._getType();
const msgType = msg.getType();
return typesAsStrings.some((t) => t === msgType);
};

Expand Down Expand Up @@ -279,8 +280,8 @@ function _mergeMessageRuns(messages: BaseMessage[]): BaseMessage[] {
if (!last) {
merged.push(curr);
} else if (
curr._getType() === "tool" ||
!(curr._getType() === last._getType())
curr.getType() === "tool" ||
!(curr.getType() === last.getType())
) {
merged.push(last, curr);
} else {
Expand Down Expand Up @@ -767,7 +768,7 @@ async function _firstMaxTokens(
([k]) => k !== "type" && !k.startsWith("lc_")
)
) as BaseMessageFields;
const updatedMessage = _switchTypeToMessage(excluded._getType(), {
const updatedMessage = _switchTypeToMessage(excluded.getType(), {
...fields,
content: partialContent,
});
Expand Down Expand Up @@ -862,7 +863,18 @@ async function _lastMaxTokens(
} = options;

// Create a copy of messages to avoid mutation
let messagesCopy = [...messages];
let messagesCopy = messages.map((message) => {
const fields = Object.fromEntries(
Object.entries(message).filter(
([k]) => k !== "type" && !k.startsWith("lc_")
)
) as BaseMessageFields;
return _switchTypeToMessage(
message.getType(),
fields,
isBaseMessageChunk(message)
);
});

if (endOn) {
const endOnArr = Array.isArray(endOn) ? endOn : [endOn];
Expand All @@ -875,7 +887,7 @@ async function _lastMaxTokens(
}

const swappedSystem =
includeSystem && messagesCopy[0]?._getType() === "system";
includeSystem && messagesCopy[0]?.getType() === "system";
let reversed_ = swappedSystem
? messagesCopy.slice(0, 1).concat(messagesCopy.slice(1).reverse())
: messagesCopy.reverse();
Expand Down Expand Up @@ -943,6 +955,11 @@ function _switchTypeToMessage(
fields: BaseMessageFields,
returnChunk: true
): BaseMessageChunk;
function _switchTypeToMessage(
messageType: MessageType,
fields: BaseMessageFields,
returnChunk?: boolean
): BaseMessageChunk | BaseMessage;
function _switchTypeToMessage(
messageType: MessageType,
fields: BaseMessageFields,
Expand Down Expand Up @@ -1058,7 +1075,7 @@ function _switchTypeToMessage(
}

function _chunkToMsg(chunk: BaseMessageChunk): BaseMessage {
const chunkType = chunk._getType();
const chunkType = chunk.getType();
let msg: BaseMessage | undefined;
const fields = Object.fromEntries(
Object.entries(chunk).filter(
Expand Down

0 comments on commit 93862f4

Please sign in to comment.