Skip to content

Commit

Permalink
Merge branch 'main' into saved-prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamine authored Jan 28, 2025
2 parents 94c7de5 + 4110209 commit 8ee2fcf
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 83 deletions.
2 changes: 1 addition & 1 deletion api/app/clients/AnthropicClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ class AnthropicClient extends BaseClient {
}

let { context: messagesInWindow, remainingContextTokens } =
await this.getMessagesWithinTokenLimit(formattedMessages);
await this.getMessagesWithinTokenLimit({ messages: formattedMessages });

const tokenCountMap = orderedMessages
.slice(orderedMessages.length - messagesInWindow.length)
Expand Down
72 changes: 58 additions & 14 deletions api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -347,25 +347,38 @@ class BaseClient {
* If the token limit would be exceeded by adding a message, that message is not added to the context and remains in the original array.
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the context array at the end to maintain the original order of the messages.
*
* @param {Array} _messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
* @param {number} [maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
* @returns {Object} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
* @param {Object} params
* @param {TMessage[]} params.messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
* @param {number} [params.maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
* @param {{ role: 'system', content: text, tokenCount: number }} [params.instructions] - Instructions already added to the context at index 0.
* @returns {Promise<{
* context: TMessage[],
* remainingContextTokens: number,
* messagesToRefine: TMessage[],
* summaryIndex: number,
* }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
* `context` is an array of messages that fit within the token limit.
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
*/
async getMessagesWithinTokenLimit(_messages, maxContextTokens) {
async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) {
// Every reply is primed with <|start|>assistant<|message|>, so we
// start with 3 tokens for the label after all messages have been counted.
let currentTokenCount = 3;
let summaryIndex = -1;
let remainingContextTokens = maxContextTokens ?? this.maxContextTokens;
let currentTokenCount = 3;
const instructionsTokenCount = instructions?.tokenCount ?? 0;
let remainingContextTokens =
(maxContextTokens ?? this.maxContextTokens) - instructionsTokenCount;
const messages = [..._messages];

const context = [];

if (currentTokenCount < remainingContextTokens) {
while (messages.length > 0 && currentTokenCount < remainingContextTokens) {
if (messages.length === 1 && instructions) {
break;
}
const poppedMessage = messages.pop();
const { tokenCount } = poppedMessage;

Expand All @@ -379,6 +392,11 @@ class BaseClient {
}
}

if (instructions) {
context.push(_messages[0]);
messages.shift();
}

const prunedMemory = messages;
summaryIndex = prunedMemory.length - 1;
remainingContextTokens -= currentTokenCount;
Expand All @@ -403,27 +421,38 @@ class BaseClient {
if (instructions) {
({ tokenCount, ..._instructions } = instructions);
}

_instructions && logger.debug('[BaseClient] instructions tokenCount: ' + tokenCount);
let payload = this.addInstructions(formattedMessages, _instructions);
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
if (tokenCount && tokenCount > this.maxContextTokens) {
const info = `${tokenCount} / ${this.maxContextTokens}`;
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
logger.warn(`Instructions token count exceeds max token count (${info}).`);
throw new Error(errorMessage);
}

if (this.clientName === EModelEndpoint.agents) {
const { dbMessages, editedIndices } = truncateToolCallOutputs(
orderedWithInstructions,
orderedMessages,
this.maxContextTokens,
this.getTokenCountForMessage.bind(this),
);

if (editedIndices.length > 0) {
logger.debug('[BaseClient] Truncated tool call outputs:', editedIndices);
for (const index of editedIndices) {
payload[index].content = dbMessages[index].content;
formattedMessages[index].content = dbMessages[index].content;
}
orderedWithInstructions = dbMessages;
orderedMessages = dbMessages;
}
}

let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);

let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
await this.getMessagesWithinTokenLimit(orderedWithInstructions);
await this.getMessagesWithinTokenLimit({
messages: orderedWithInstructions,
instructions,
});

logger.debug('[BaseClient] Context Count (1/2)', {
remainingContextTokens,
Expand All @@ -435,7 +464,9 @@ class BaseClient {
let { shouldSummarize } = this;

// Calculate the difference in length to determine how many messages were discarded if any
const { length } = payload;
let payload;
let { length } = formattedMessages;
length += instructions != null ? 1 : 0;
const diff = length - context.length;
const firstMessage = orderedWithInstructions[0];
const usePrevSummary =
Expand All @@ -445,18 +476,31 @@ class BaseClient {
this.previous_summary.messageId === firstMessage.messageId;

if (diff > 0) {
payload = payload.slice(diff);
payload = formattedMessages.slice(diff);
logger.debug(
`[BaseClient] Difference between original payload (${length}) and context (${context.length}): ${diff}`,
);
}

payload = this.addInstructions(payload ?? formattedMessages, _instructions);

const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1];
if (payload.length === 0 && !shouldSummarize && latestMessage) {
const info = `${latestMessage.tokenCount} / ${this.maxContextTokens}`;
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
logger.warn(`Prompt token count exceeds max token count (${info}).`);
throw new Error(errorMessage);
} else if (
_instructions &&
payload.length === 1 &&
payload[0].content === _instructions.content
) {
const info = `${tokenCount + 3} / ${this.maxContextTokens}`;
const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
logger.warn(
`Including instructions, the prompt token count exceeds remaining max token count (${info}).`,
);
throw new Error(errorMessage);
}

if (usePrevSummary) {
Expand Down
5 changes: 4 additions & 1 deletion api/app/clients/OpenAIClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,10 @@ ${convo}
);

if (excessTokenCount > maxContextTokens) {
({ context } = await this.getMessagesWithinTokenLimit(context, maxContextTokens));
({ context } = await this.getMessagesWithinTokenLimit({
messages: context,
maxContextTokens,
}));
}

if (context.length === 0) {
Expand Down
172 changes: 110 additions & 62 deletions api/app/clients/specs/BaseClient.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ describe('BaseClient', () => {
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);

const result = await TestClient.getMessagesWithinTokenLimit(messages);
const result = await TestClient.getMessagesWithinTokenLimit({ messages });

expect(result.context).toEqual(expectedContext);
expect(result.summaryIndex).toEqual(expectedIndex);
Expand Down Expand Up @@ -195,74 +195,14 @@ describe('BaseClient', () => {
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);

const result = await TestClient.getMessagesWithinTokenLimit(messages);
const result = await TestClient.getMessagesWithinTokenLimit({ messages });

expect(result.context).toEqual(expectedContext);
expect(result.summaryIndex).toEqual(expectedIndex);
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
});

test('handles context strategy correctly in handleContextStrategy()', async () => {
TestClient.addInstructions = jest
.fn()
.mockReturnValue([
{ content: 'Hello' },
{ content: 'How can I help you?' },
{ content: 'Please provide more details.' },
{ content: 'I can assist you with that.' },
]);
TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({
context: [
{ content: 'How can I help you?' },
{ content: 'Please provide more details.' },
{ content: 'I can assist you with that.' },
],
remainingContextTokens: 80,
messagesToRefine: [{ content: 'Hello' }],
summaryIndex: 3,
});

TestClient.getTokenCount = jest.fn().mockReturnValue(40);

const instructions = { content: 'Please provide more details.' };
const orderedMessages = [
{ content: 'Hello' },
{ content: 'How can I help you?' },
{ content: 'Please provide more details.' },
{ content: 'I can assist you with that.' },
];
const formattedMessages = [
{ content: 'Hello' },
{ content: 'How can I help you?' },
{ content: 'Please provide more details.' },
{ content: 'I can assist you with that.' },
];
const expectedResult = {
payload: [
{
role: 'system',
content: 'Refined answer',
},
{ content: 'How can I help you?' },
{ content: 'Please provide more details.' },
{ content: 'I can assist you with that.' },
],
promptTokens: expect.any(Number),
tokenCountMap: {},
messages: expect.any(Array),
};

TestClient.shouldSummarize = true;
const result = await TestClient.handleContextStrategy({
instructions,
orderedMessages,
formattedMessages,
});

expect(result).toEqual(expectedResult);
});

describe('getMessagesForConversation', () => {
it('should return an empty array if the parentMessageId does not exist', () => {
const result = TestClient.constructor.getMessagesForConversation({
Expand Down Expand Up @@ -674,4 +614,112 @@ describe('BaseClient', () => {
expect(calls[1][0].isCreatedByUser).toBe(false); // Second call should be for response message
});
});

describe('getMessagesWithinTokenLimit with instructions', () => {
test('should always include instructions when present', async () => {
TestClient.maxContextTokens = 50;
const instructions = {
role: 'system',
content: 'System instructions',
tokenCount: 20,
};

const messages = [
instructions,
{ role: 'user', content: 'Hello', tokenCount: 10 },
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
];

const result = await TestClient.getMessagesWithinTokenLimit({
messages,
instructions,
});

expect(result.context[0]).toBe(instructions);
expect(result.remainingContextTokens).toBe(2);
});

test('should handle case when messages exceed limit but instructions must be preserved', async () => {
TestClient.maxContextTokens = 30;
const instructions = {
role: 'system',
content: 'System instructions',
tokenCount: 20,
};

const messages = [
instructions,
{ role: 'user', content: 'Hello', tokenCount: 10 },
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
];

const result = await TestClient.getMessagesWithinTokenLimit({
messages,
instructions,
});

// Should only include instructions and the last message that fits
expect(result.context).toHaveLength(1);
expect(result.context[0].content).toBe(instructions.content);
expect(result.messagesToRefine).toHaveLength(2);
expect(result.remainingContextTokens).toBe(7); // 30 - 20 - 3 (assistant label)
});

test('should work correctly without instructions (1/2)', async () => {
TestClient.maxContextTokens = 50;
const messages = [
{ role: 'user', content: 'Hello', tokenCount: 10 },
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
];

const result = await TestClient.getMessagesWithinTokenLimit({
messages,
});

expect(result.context).toHaveLength(2);
expect(result.remainingContextTokens).toBe(22); // 50 - 10 - 15 - 3(assistant label)
expect(result.messagesToRefine).toHaveLength(0);
});

test('should work correctly without instructions (2/2)', async () => {
TestClient.maxContextTokens = 30;
const messages = [
{ role: 'user', content: 'Hello', tokenCount: 10 },
{ role: 'assistant', content: 'Hi there', tokenCount: 20 },
];

const result = await TestClient.getMessagesWithinTokenLimit({
messages,
});

expect(result.context).toHaveLength(1);
expect(result.remainingContextTokens).toBe(7);
expect(result.messagesToRefine).toHaveLength(1);
});

test('should handle case when only instructions fit within limit', async () => {
TestClient.maxContextTokens = 25;
const instructions = {
role: 'system',
content: 'System instructions',
tokenCount: 20,
};

const messages = [
instructions,
{ role: 'user', content: 'Hello', tokenCount: 10 },
{ role: 'assistant', content: 'Hi there', tokenCount: 15 },
];

const result = await TestClient.getMessagesWithinTokenLimit({
messages,
instructions,
});

expect(result.context).toHaveLength(1);
expect(result.context[0]).toBe(instructions);
expect(result.messagesToRefine).toHaveLength(2);
expect(result.remainingContextTokens).toBe(2); // 25 - 20 - 3(assistant label)
});
});
});
2 changes: 1 addition & 1 deletion api/server/controllers/auth/LogoutController.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ const { logger } = require('~/config');
const logoutController = async (req, res) => {
const refreshToken = req.headers.cookie ? cookies.parse(req.headers.cookie).refreshToken : null;
try {
const logout = await logoutUser(req.user._id, refreshToken);
const logout = await logoutUser(req, refreshToken);
const { status, message } = logout;
res.clearCookie('refreshToken');
return res.status(status).send({ message });
Expand Down
Loading

0 comments on commit 8ee2fcf

Please sign in to comment.