From 4110209494910fe2420541e13e7dbcf107a6a0b8 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Mon, 27 Jan 2025 20:37:38 -0500 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20fix:=20Prevent=20Instructi?= =?UTF-8?q?ons=20from=20Removal=20when=20nearing=20Max=20Context=20(#5516)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: getMessagesWithinTokenLimit to accept params object * refactor: always include instructions in payload if provided * ci: remove obsolete test * refactor: update logoutUser to accept request object and handle session destruction * test: enhance getMessagesWithinTokenLimit tests for instruction handling --- api/app/clients/AnthropicClient.js | 2 +- api/app/clients/BaseClient.js | 72 ++++++-- api/app/clients/OpenAIClient.js | 5 +- api/app/clients/specs/BaseClient.test.js | 172 +++++++++++------- .../controllers/auth/LogoutController.js | 2 +- api/server/services/AuthService.js | 15 +- 6 files changed, 185 insertions(+), 83 deletions(-) diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 8dc0e40d565..522b6beb4fb 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -416,7 +416,7 @@ class AnthropicClient extends BaseClient { } let { context: messagesInWindow, remainingContextTokens } = - await this.getMessagesWithinTokenLimit(formattedMessages); + await this.getMessagesWithinTokenLimit({ messages: formattedMessages }); const tokenCountMap = orderedMessages .slice(orderedMessages.length - messagesInWindow.length) diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index d8d9035b109..6cda4471421 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -347,25 +347,38 @@ class BaseClient { * If the token limit would be exceeded by adding a message, that message is not added to the context and remains in the original array. * The method uses `push` and `pop` operations for efficient array manipulation, and reverses the context array at the end to maintain the original order of the messages. * - * @param {Array} _messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest. - * @param {number} [maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`. - * @returns {Object} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`. + * @param {Object} params + * @param {TMessage[]} params.messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest. + * @param {number} [params.maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`. + * @param {{ role: 'system', content: text, tokenCount: number }} [params.instructions] - Instructions already added to the context at index 0. + * @returns {Promise<{ + * context: TMessage[], + * remainingContextTokens: number, + * messagesToRefine: TMessage[], + * summaryIndex: number, + * }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`. * `context` is an array of messages that fit within the token limit. * `summaryIndex` is the index of the first message in the `messagesToRefine` array. * `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. * `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit. */ - async getMessagesWithinTokenLimit(_messages, maxContextTokens) { + async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) { // Every reply is primed with <|start|>assistant<|message|>, so we // start with 3 tokens for the label after all messages have been counted. - let currentTokenCount = 3; let summaryIndex = -1; - let remainingContextTokens = maxContextTokens ?? this.maxContextTokens; + let currentTokenCount = 3; + const instructionsTokenCount = instructions?.tokenCount ?? 0; + let remainingContextTokens = + (maxContextTokens ?? this.maxContextTokens) - instructionsTokenCount; const messages = [..._messages]; const context = []; + if (currentTokenCount < remainingContextTokens) { while (messages.length > 0 && currentTokenCount < remainingContextTokens) { + if (messages.length === 1 && instructions) { + break; + } const poppedMessage = messages.pop(); const { tokenCount } = poppedMessage; @@ -379,6 +392,11 @@ class BaseClient { } } + if (instructions) { + context.push(_messages[0]); + messages.shift(); + } + const prunedMemory = messages; summaryIndex = prunedMemory.length - 1; remainingContextTokens -= currentTokenCount; @@ -403,12 +421,18 @@ class BaseClient { if (instructions) { ({ tokenCount, ..._instructions } = instructions); } + _instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount); - let payload = this.addInstructions(formattedMessages, _instructions); - let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); + if (tokenCount && tokenCount > this.maxContextTokens) { + const info = `${tokenCount} / ${this.maxContextTokens}`; + const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`; + logger.warn(`Instructions token count exceeds max token count (${info}).`); + throw new Error(errorMessage); + } + if (this.clientName === EModelEndpoint.agents) { const { dbMessages, editedIndices } = truncateToolCallOutputs( - orderedWithInstructions, + orderedMessages, this.maxContextTokens, this.getTokenCountForMessage.bind(this), ); @@ -416,14 +440,19 @@ class BaseClient { if (editedIndices.length > 0) { logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices); for (const index of editedIndices) { - payload[index].content = dbMessages[index].content; + formattedMessages[index].content = dbMessages[index].content; } - orderedWithInstructions = dbMessages; + orderedMessages = dbMessages; } } + let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); + let { context, remainingContextTokens, messagesToRefine, summaryIndex } = - await this.getMessagesWithinTokenLimit(orderedWithInstructions); + await this.getMessagesWithinTokenLimit({ + messages: orderedWithInstructions, + instructions, + }); logger.debug('[BaseClient] Context Count (1/2)', { remainingContextTokens, @@ -435,7 +464,9 @@ class BaseClient { let { shouldSummarize } = this; // Calculate the difference in length to determine how many messages were discarded if any - const { length } = payload; + let payload; + let { length } = formattedMessages; + length += instructions != null ? 1 : 0; const diff = length - context.length; const firstMessage = orderedWithInstructions[0]; const usePrevSummary = @@ -445,18 +476,31 @@ class BaseClient { this.previous_summary.messageId === firstMessage.messageId; if (diff > 0) { - payload = payload.slice(diff); + payload = formattedMessages.slice(diff); logger.debug( `[BaseClient] Difference between original payload (${length}) and context (${context.length}): ${diff}`, ); } + payload = this.addInstructions(payload ?? formattedMessages, _instructions); + const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1]; if (payload.length === 0 && !shouldSummarize && latestMessage) { const info = `${latestMessage.tokenCount} / ${this.maxContextTokens}`; const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`; logger.warn(`Prompt token count exceeds max token count (${info}).`); throw new Error(errorMessage); + } else if ( + _instructions && + payload.length === 1 && + payload[0].content === _instructions.content + ) { + const info = `${tokenCount + 3} / ${this.maxContextTokens}`; + const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`; + logger.warn( + `Including instructions, the prompt token count exceeds remaining max token count (${info}).`, + ); + throw new Error(errorMessage); } if (usePrevSummary) { diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index d084f221876..89b938b8582 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -931,7 +931,10 @@ ${convo} ); if (excessTokenCount > maxContextTokens) { - ({ context } = await this.getMessagesWithinTokenLimit(context, maxContextTokens)); + ({ context } = await this.getMessagesWithinTokenLimit({ + messages: context, + maxContextTokens, + })); } if (context.length === 0) { diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index af8f28a2c21..e899449fb90 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -159,7 +159,7 @@ describe('BaseClient', () => { expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {}; const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content); - const result = await TestClient.getMessagesWithinTokenLimit(messages); + const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); expect(result.summaryIndex).toEqual(expectedIndex); @@ -195,7 +195,7 @@ describe('BaseClient', () => { expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {}; const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content); - const result = await TestClient.getMessagesWithinTokenLimit(messages); + const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); expect(result.summaryIndex).toEqual(expectedIndex); @@ -203,66 +203,6 @@ describe('BaseClient', () => { expect(result.messagesToRefine).toEqual(expectedMessagesToRefine); }); - test('handles context strategy correctly in handleContextStrategy()', async () => { - TestClient.addInstructions = jest - .fn() - .mockReturnValue([ - { content: 'Hello' }, - { content: 'How can I help you?' }, - { content: 'Please provide more details.' }, - { content: 'I can assist you with that.' }, - ]); - TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({ - context: [ - { content: 'How can I help you?' }, - { content: 'Please provide more details.' }, - { content: 'I can assist you with that.' }, - ], - remainingContextTokens: 80, - messagesToRefine: [{ content: 'Hello' }], - summaryIndex: 3, - }); - - TestClient.getTokenCount = jest.fn().mockReturnValue(40); - - const instructions = { content: 'Please provide more details.' }; - const orderedMessages = [ - { content: 'Hello' }, - { content: 'How can I help you?' }, - { content: 'Please provide more details.' }, - { content: 'I can assist you with that.' }, - ]; - const formattedMessages = [ - { content: 'Hello' }, - { content: 'How can I help you?' }, - { content: 'Please provide more details.' }, - { content: 'I can assist you with that.' }, - ]; - const expectedResult = { - payload: [ - { - role: 'system', - content: 'Refined answer', - }, - { content: 'How can I help you?' }, - { content: 'Please provide more details.' }, - { content: 'I can assist you with that.' }, - ], - promptTokens: expect.any(Number), - tokenCountMap: {}, - messages: expect.any(Array), - }; - - TestClient.shouldSummarize = true; - const result = await TestClient.handleContextStrategy({ - instructions, - orderedMessages, - formattedMessages, - }); - - expect(result).toEqual(expectedResult); - }); - describe('getMessagesForConversation', () => { it('should return an empty array if the parentMessageId does not exist', () => { const result = TestClient.constructor.getMessagesForConversation({ @@ -674,4 +614,112 @@ describe('BaseClient', () => { expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message }); }); + + describe('getMessagesWithinTokenLimit with instructions', () => { + test('should always include instructions when present', async () => { + TestClient.maxContextTokens = 50; + const instructions = { + role: 'system', + content: 'System instructions', + tokenCount: 20, + }; + + const messages = [ + instructions, + { role: 'user', content: 'Hello', tokenCount: 10 }, + { role: 'assistant', content: 'Hi there', tokenCount: 15 }, + ]; + + const result = await TestClient.getMessagesWithinTokenLimit({ + messages, + instructions, + }); + + expect(result.context[0]).toBe(instructions); + expect(result.remainingContextTokens).toBe(2); + }); + + test('should handle case when messages exceed limit but instructions must be preserved', async () => { + TestClient.maxContextTokens = 30; + const instructions = { + role: 'system', + content: 'System instructions', + tokenCount: 20, + }; + + const messages = [ + instructions, + { role: 'user', content: 'Hello', tokenCount: 10 }, + { role: 'assistant', content: 'Hi there', tokenCount: 15 }, + ]; + + const result = await TestClient.getMessagesWithinTokenLimit({ + messages, + instructions, + }); + + // Should only include instructions and the last message that fits + expect(result.context).toHaveLength(1); + expect(result.context[0].content).toBe(instructions.content); + expect(result.messagesToRefine).toHaveLength(2); + expect(result.remainingContextTokens).toBe(7); // 30 - 20 - 3 (assistant label) + }); + + test('should work correctly without instructions (1/2)', async () => { + TestClient.maxContextTokens = 50; + const messages = [ + { role: 'user', content: 'Hello', tokenCount: 10 }, + { role: 'assistant', content: 'Hi there', tokenCount: 15 }, + ]; + + const result = await TestClient.getMessagesWithinTokenLimit({ + messages, + }); + + expect(result.context).toHaveLength(2); + expect(result.remainingContextTokens).toBe(22); // 50 - 10 - 15 - 3(assistant label) + expect(result.messagesToRefine).toHaveLength(0); + }); + + test('should work correctly without instructions (2/2)', async () => { + TestClient.maxContextTokens = 30; + const messages = [ + { role: 'user', content: 'Hello', tokenCount: 10 }, + { role: 'assistant', content: 'Hi there', tokenCount: 20 }, + ]; + + const result = await TestClient.getMessagesWithinTokenLimit({ + messages, + }); + + expect(result.context).toHaveLength(1); + expect(result.remainingContextTokens).toBe(7); + expect(result.messagesToRefine).toHaveLength(1); + }); + + test('should handle case when only instructions fit within limit', async () => { + TestClient.maxContextTokens = 25; + const instructions = { + role: 'system', + content: 'System instructions', + tokenCount: 20, + }; + + const messages = [ + instructions, + { role: 'user', content: 'Hello', tokenCount: 10 }, + { role: 'assistant', content: 'Hi there', tokenCount: 15 }, + ]; + + const result = await TestClient.getMessagesWithinTokenLimit({ + messages, + instructions, + }); + + expect(result.context).toHaveLength(1); + expect(result.context[0]).toBe(instructions); + expect(result.messagesToRefine).toHaveLength(2); + expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label) + }); + }); }); diff --git a/api/server/controllers/auth/LogoutController.js b/api/server/controllers/auth/LogoutController.js index b09b8722aa1..7d010f08593 100644 --- a/api/server/controllers/auth/LogoutController.js +++ b/api/server/controllers/auth/LogoutController.js @@ -5,7 +5,7 @@ const { logger } = require('~/config'); const logoutController = async (req, res) => { const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null; try { - const logout = await logoutUser(req.user._id, refreshToken); + const logout = await logoutUser(req, refreshToken); const { status, message } = logout; res.clearCookie('refreshToken'); return res.status(status).send({ message }); diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index 680a5066029..3c02b7eea02 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -35,13 +35,14 @@ const genericVerificationMessage = 'Please check your email to verify your email /** * Logout user * - * @param {String} userId - * @param {*} refreshToken + * @param {ServerRequest} req + * @param {string} refreshToken * @returns */ -const logoutUser = async (userId, refreshToken) => { +const logoutUser = async (req, refreshToken) => { try { - const session = await findSession({ userId: userId, refreshToken: refreshToken }); + const userId = req.user._id; + const session = await findSession({ userId: userId, refreshToken }); if (session) { try { @@ -52,6 +53,12 @@ const logoutUser = async (userId, refreshToken) => { } } + try { + req.session.destroy(); + } catch (destroyErr) { + logger.error('[logoutUser] Failed to destroy session.', destroyErr); + } + return { status: 200, message: 'Logout successful' }; } catch (err) { return { status: 500, message: err.message };