diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index d1690672e37..7a4d345d74d 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -569,11 +569,11 @@ class BaseClient { * the message is considered a root message. * * @param {Object} options - The options for the function. - * @param {Array} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property. + * @param {TMessage[]} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property. * @param {string} options.parentMessageId - The ID of the parent message to start the traversal from. * @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. If provided, it will be applied to each message in the resulting array. * @param {boolean} [options.summary=false] - If set to true, the traversal modifies messages with 'summary' and 'summaryTokenCount' properties and stops at the message with a 'summary' property. - * @returns {Array} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'. + * @returns {TMessage[]} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'. */ static getMessagesForConversation({ messages, diff --git a/api/models/Conversation.js b/api/models/Conversation.js index e97c0554ff0..4c700425feb 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -2,6 +2,12 @@ const Conversation = require('./schema/convoSchema'); const { getMessages, deleteMessages } = require('./Message'); const logger = require('~/config/winston'); +/** + * Retrieves a single conversation for a given user and conversation ID. + * @param {string} user - The user's ID. + * @param {string} conversationId - The conversation's ID. + * @returns {Promise} The conversation object. + */ const getConvo = async (user, conversationId) => { try { return await Conversation.findOne({ user, conversationId }).lean(); diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 52c906fc6e8..e478e48e5ef 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -6,6 +6,7 @@ const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models const { IMPORT_CONVERSATION_JOB_NAME } = require('~/server/utils/import/jobDefinition'); const { storage, importFileFilter } = require('~/server/routes/files/multer'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); +const { forkConversation } = require('~/server/utils/import/fork'); const { createImportLimiters } = require('~/server/middleware'); const jobScheduler = require('~/server/utils/jobScheduler'); const getLogStores = require('~/cache/getLogStores'); @@ -131,6 +132,35 @@ router.post( }, ); +/** + * POST /fork + * This route handles forking a conversation based on the TForkConvoRequest and responds with TForkConvoResponse. + * @route POST /fork + * @param {express.Request<{}, TForkConvoResponse, TForkConvoRequest>} req - Express request object. + * @param {express.Response} res - Express response object. + * @returns {Promise} - The response after forking the conversation. + */ +router.post('/fork', async (req, res) => { + try { + /** @type {TForkConvoRequest} */ + const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body; + const result = await forkConversation({ + requestUserId: req.user.id, + originalConvoId: conversationId, + targetMessageId: messageId, + latestMessageId, + records: true, + splitAtTarget, + option, + }); + + res.json(result); + } catch (error) { + logger.error('Error forking conversation', error); + res.status(500).send('Error forking conversation'); + } +}); + // Get the status of an import job for polling router.get('/import/jobs/:jobId', async (req, res) => { try { diff --git a/api/server/utils/import/fork.js b/api/server/utils/import/fork.js new file mode 100644 index 00000000000..cb75d7863bb --- /dev/null +++ b/api/server/utils/import/fork.js @@ -0,0 +1,314 @@ +const { v4: uuidv4 } = require('uuid'); +const { EModelEndpoint, Constants, ForkOptions } = require('librechat-data-provider'); +const { createImportBatchBuilder } = require('./importBatchBuilder'); +const BaseClient = require('~/app/clients/BaseClient'); +const { getConvo } = require('~/models/Conversation'); +const { getMessages } = require('~/models/Message'); +const logger = require('~/config/winston'); + +/** + * + * @param {object} params - The parameters for the importer. + * @param {string} params.originalConvoId - The ID of the conversation to fork. + * @param {string} params.targetMessageId - The ID of the message to fork from. + * @param {string} params.requestUserId - The ID of the user making the request. + * @param {string} [params.newTitle] - Optional new title for the forked conversation uses old title if not provided + * @param {string} [params.option=''] - Optional flag for fork option + * @param {boolean} [params.records=false] - Optional flag for returning actual database records or resulting conversation and messages. + * @param {boolean} [params.splitAtTarget=false] - Optional flag for splitting the messages at the target message level. + * @param {string} [params.latestMessageId] - latestMessageId - Required if splitAtTarget is true. + * @param {(userId: string) => ImportBatchBuilder} [params.builderFactory] - Optional factory function for creating an ImportBatchBuilder instance. + * @returns {Promise} The response after forking the conversation. + */ +async function forkConversation({ + originalConvoId, + targetMessageId: targetId, + requestUserId, + newTitle, + option = ForkOptions.TARGET_LEVEL, + records = false, + splitAtTarget = false, + latestMessageId, + builderFactory = createImportBatchBuilder, +}) { + try { + const originalConvo = await getConvo(requestUserId, originalConvoId); + let originalMessages = await getMessages({ + user: requestUserId, + conversationId: originalConvoId, + }); + + let targetMessageId = targetId; + if (splitAtTarget && !latestMessageId) { + throw new Error('Latest `messageId` is required for forking from target message.'); + } else if (splitAtTarget) { + originalMessages = splitAtTargetLevel(originalMessages, targetId); + targetMessageId = latestMessageId; + } + + const importBatchBuilder = builderFactory(requestUserId); + importBatchBuilder.startConversation(originalConvo.endpoint ?? EModelEndpoint.openAI); + + let messagesToClone = []; + + if (option === ForkOptions.DIRECT_PATH) { + // Direct path only + messagesToClone = BaseClient.getMessagesForConversation({ + messages: originalMessages, + parentMessageId: targetMessageId, + }); + } else if (option === ForkOptions.INCLUDE_BRANCHES) { + // Direct path and siblings + messagesToClone = getAllMessagesUpToParent(originalMessages, targetMessageId); + } else if (option === ForkOptions.TARGET_LEVEL || !option) { + // Direct path, siblings, and all descendants + messagesToClone = getMessagesUpToTargetLevel(originalMessages, targetMessageId); + } + + const idMapping = new Map(); + + for (const message of messagesToClone) { + const newMessageId = uuidv4(); + idMapping.set(message.messageId, newMessageId); + + const clonedMessage = { + ...message, + messageId: newMessageId, + parentMessageId: + message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT + ? idMapping.get(message.parentMessageId) + : Constants.NO_PARENT, + }; + + importBatchBuilder.saveMessage(clonedMessage); + } + + const result = importBatchBuilder.finishConversation( + newTitle || originalConvo.title, + new Date(), + originalConvo, + ); + await importBatchBuilder.saveBatch(); + logger.debug( + `user: ${requestUserId} | New conversation "${ + newTitle || originalConvo.title + }" forked from conversation ID ${originalConvoId}`, + ); + + if (!records) { + return result; + } + + const conversation = await getConvo(requestUserId, result.conversation.conversationId); + const messages = await getMessages({ + user: requestUserId, + conversationId: conversation.conversationId, + }); + + return { + conversation, + messages, + }; + } catch (error) { + logger.error( + `user: ${requestUserId} | Error forking conversation from original ID ${originalConvoId}`, + error, + ); + throw error; + } +} + +/** + * Retrieves all messages up to the root from the target message. + * @param {TMessage[]} messages - The list of messages to search. + * @param {string} targetMessageId - The ID of the target message. + * @returns {TMessage[]} The list of messages up to the root from the target message. + */ +function getAllMessagesUpToParent(messages, targetMessageId) { + const targetMessage = messages.find((msg) => msg.messageId === targetMessageId); + if (!targetMessage) { + return []; + } + + const pathToRoot = new Set(); + const visited = new Set(); + let current = targetMessage; + + while (current) { + if (visited.has(current.messageId)) { + break; + } + + visited.add(current.messageId); + pathToRoot.add(current.messageId); + + const currentParentId = current.parentMessageId ?? Constants.NO_PARENT; + if (currentParentId === Constants.NO_PARENT) { + break; + } + + current = messages.find((msg) => msg.messageId === currentParentId); + } + + // Include all messages that are in the path or whose parent is in the path + // Exclude children of the target message + return messages.filter( + (msg) => + (pathToRoot.has(msg.messageId) && msg.messageId !== targetMessageId) || + (pathToRoot.has(msg.parentMessageId) && msg.parentMessageId !== targetMessageId) || + msg.messageId === targetMessageId, + ); +} + +/** + * Retrieves all messages up to the root from the target message and its neighbors. + * @param {TMessage[]} messages - The list of messages to search. + * @param {string} targetMessageId - The ID of the target message. + * @returns {TMessage[]} The list of inclusive messages up to the root from the target message. + */ +function getMessagesUpToTargetLevel(messages, targetMessageId) { + if (messages.length === 1 && messages[0] && messages[0].messageId === targetMessageId) { + return messages; + } + + // Create a map of parentMessageId to children messages + const parentToChildrenMap = new Map(); + for (const message of messages) { + if (!parentToChildrenMap.has(message.parentMessageId)) { + parentToChildrenMap.set(message.parentMessageId, []); + } + parentToChildrenMap.get(message.parentMessageId).push(message); + } + + // Retrieve the target message + const targetMessage = messages.find((msg) => msg.messageId === targetMessageId); + if (!targetMessage) { + logger.error('Target message not found.'); + return []; + } + + const visited = new Set(); + + const rootMessages = parentToChildrenMap.get(Constants.NO_PARENT) || []; + let currentLevel = rootMessages.length > 0 ? [...rootMessages] : [targetMessage]; + const results = new Set(currentLevel); + + // Check if the target message is at the root level + if ( + currentLevel.some((msg) => msg.messageId === targetMessageId) && + targetMessage.parentMessageId === Constants.NO_PARENT + ) { + return Array.from(results); + } + + // Iterate level by level until the target is found + let targetFound = false; + while (!targetFound && currentLevel.length > 0) { + const nextLevel = []; + for (const node of currentLevel) { + if (visited.has(node.messageId)) { + logger.warn('Cycle detected in message tree'); + continue; + } + visited.add(node.messageId); + const children = parentToChildrenMap.get(node.messageId) || []; + for (const child of children) { + if (visited.has(child.messageId)) { + logger.warn('Cycle detected in message tree'); + continue; + } + nextLevel.push(child); + results.add(child); + if (child.messageId === targetMessageId) { + targetFound = true; + } + } + } + currentLevel = nextLevel; + } + + return Array.from(results); +} + +/** + * Splits the conversation at the targeted message level, including the target, its siblings, and all descendant messages. + * All target level messages have their parentMessageId set to the root. + * @param {TMessage[]} messages - The list of messages to analyze. + * @param {string} targetMessageId - The ID of the message to start the split from. + * @returns {TMessage[]} The list of messages at and below the target level. + */ +function splitAtTargetLevel(messages, targetMessageId) { + // Create a map of parentMessageId to children messages + const parentToChildrenMap = new Map(); + for (const message of messages) { + if (!parentToChildrenMap.has(message.parentMessageId)) { + parentToChildrenMap.set(message.parentMessageId, []); + } + parentToChildrenMap.get(message.parentMessageId).push(message); + } + + // Retrieve the target message + const targetMessage = messages.find((msg) => msg.messageId === targetMessageId); + if (!targetMessage) { + logger.error('Target message not found.'); + return []; + } + + // Initialize the search with root messages + const rootMessages = parentToChildrenMap.get(Constants.NO_PARENT) || []; + let currentLevel = [...rootMessages]; + let currentLevelIndex = 0; + const levelMap = {}; + + // Map messages to their levels + rootMessages.forEach((msg) => { + levelMap[msg.messageId] = 0; + }); + + // Search for the target level + while (currentLevel.length > 0) { + const nextLevel = []; + for (const node of currentLevel) { + const children = parentToChildrenMap.get(node.messageId) || []; + for (const child of children) { + nextLevel.push(child); + levelMap[child.messageId] = currentLevelIndex + 1; + } + } + currentLevel = nextLevel; + currentLevelIndex++; + } + + // Determine the target level + const targetLevel = levelMap[targetMessageId]; + if (targetLevel === undefined) { + logger.error('Target level not found.'); + return []; + } + + // Filter messages at or below the target level + const filteredMessages = messages + .map((msg) => { + const messageLevel = levelMap[msg.messageId]; + if (messageLevel < targetLevel) { + return null; + } else if (messageLevel === targetLevel) { + return { + ...msg, + parentMessageId: Constants.NO_PARENT, + }; + } + + return msg; + }) + .filter((msg) => msg !== null); + + return filteredMessages; +} + +module.exports = { + forkConversation, + splitAtTargetLevel, + getAllMessagesUpToParent, + getMessagesUpToTargetLevel, +}; diff --git a/api/server/utils/import/fork.spec.js b/api/server/utils/import/fork.spec.js new file mode 100644 index 00000000000..f4f4a2b81ee --- /dev/null +++ b/api/server/utils/import/fork.spec.js @@ -0,0 +1,574 @@ +const { Constants, ForkOptions } = require('librechat-data-provider'); + +jest.mock('~/models/Conversation', () => ({ + getConvo: jest.fn(), + bulkSaveConvos: jest.fn(), +})); + +jest.mock('~/models/Message', () => ({ + getMessages: jest.fn(), + bulkSaveMessages: jest.fn(), +})); + +let mockIdCounter = 0; +jest.mock('uuid', () => { + return { + v4: jest.fn(() => { + mockIdCounter++; + return mockIdCounter.toString(); + }), + }; +}); + +const { + forkConversation, + splitAtTargetLevel, + getAllMessagesUpToParent, + getMessagesUpToTargetLevel, +} = require('./fork'); +const { getConvo, bulkSaveConvos } = require('~/models/Conversation'); +const { getMessages, bulkSaveMessages } = require('~/models/Message'); +const BaseClient = require('~/app/clients/BaseClient'); + +/** + * + * @param {TMessage[]} messages - The list of messages to visualize. + * @param {string | null} parentId - The parent message ID. + * @param {string} prefix - The prefix to use for each line. + * @returns + */ +function printMessageTree(messages, parentId = Constants.NO_PARENT, prefix = '') { + let treeVisual = ''; + + const childMessages = messages.filter((msg) => msg.parentMessageId === parentId); + for (let index = 0; index < childMessages.length; index++) { + const msg = childMessages[index]; + const isLast = index === childMessages.length - 1; + const connector = isLast ? '└── ' : '├── '; + + treeVisual += `${prefix}${connector}[${msg.messageId}]: ${ + msg.parentMessageId !== Constants.NO_PARENT ? `Child of ${msg.parentMessageId}` : 'Root' + }\n`; + treeVisual += printMessageTree(messages, msg.messageId, prefix + (isLast ? ' ' : '| ')); + } + + return treeVisual; +} + +const mockMessages = [ + { + messageId: '0', + parentMessageId: Constants.NO_PARENT, + text: 'Root message 1', + createdAt: '2021-01-01', + }, + { + messageId: '1', + parentMessageId: Constants.NO_PARENT, + text: 'Root message 2', + createdAt: '2021-01-01', + }, + { messageId: '2', parentMessageId: '1', text: 'Child of 1', createdAt: '2021-01-02' }, + { messageId: '3', parentMessageId: '1', text: 'Child of 1', createdAt: '2021-01-03' }, + { messageId: '4', parentMessageId: '2', text: 'Child of 2', createdAt: '2021-01-04' }, + { messageId: '5', parentMessageId: '2', text: 'Child of 2', createdAt: '2021-01-05' }, + { messageId: '6', parentMessageId: '3', text: 'Child of 3', createdAt: '2021-01-06' }, + { messageId: '7', parentMessageId: '3', text: 'Child of 3', createdAt: '2021-01-07' }, + { messageId: '8', parentMessageId: '7', text: 'Child of 7', createdAt: '2021-01-07' }, +]; + +const mockConversation = { convoId: 'abc123', title: 'Original Title' }; + +describe('forkConversation', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockIdCounter = 0; + getConvo.mockResolvedValue(mockConversation); + getMessages.mockResolvedValue(mockMessages); + bulkSaveConvos.mockResolvedValue(null); + bulkSaveMessages.mockResolvedValue(null); + }); + + test('should fork conversation without branches', async () => { + const result = await forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '3', + requestUserId: 'user1', + option: ForkOptions.DIRECT_PATH, + }); + console.debug('forkConversation: direct path\n', printMessageTree(result.messages)); + + // Reversed order due to setup in function + const expectedMessagesTexts = ['Child of 1', 'Root message 2']; + expect(getMessages).toHaveBeenCalled(); + expect(bulkSaveMessages).toHaveBeenCalledWith( + expect.arrayContaining( + expectedMessagesTexts.map((text) => expect.objectContaining({ text })), + ), + ); + }); + + test('should fork conversation without branches (deeper)', async () => { + const result = await forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '8', + requestUserId: 'user1', + option: ForkOptions.DIRECT_PATH, + }); + console.debug('forkConversation: direct path (deeper)\n', printMessageTree(result.messages)); + + const expectedMessagesTexts = ['Child of 7', 'Child of 3', 'Child of 1', 'Root message 2']; + expect(getMessages).toHaveBeenCalled(); + expect(bulkSaveMessages).toHaveBeenCalledWith( + expect.arrayContaining( + expectedMessagesTexts.map((text) => expect.objectContaining({ text })), + ), + ); + }); + + test('should fork conversation with branches', async () => { + const result = await forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '3', + requestUserId: 'user1', + option: ForkOptions.INCLUDE_BRANCHES, + }); + + console.debug('forkConversation: include branches\n', printMessageTree(result.messages)); + + const expectedMessagesTexts = ['Root message 2', 'Child of 1', 'Child of 1']; + expect(getMessages).toHaveBeenCalled(); + expect(bulkSaveMessages).toHaveBeenCalledWith( + expect.arrayContaining( + expectedMessagesTexts.map((text) => expect.objectContaining({ text })), + ), + ); + }); + + test('should fork conversation up to target level', async () => { + const result = await forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '3', + requestUserId: 'user1', + option: ForkOptions.TARGET_LEVEL, + }); + + console.debug('forkConversation: target level\n', printMessageTree(result.messages)); + + const expectedMessagesTexts = ['Root message 1', 'Root message 2', 'Child of 1', 'Child of 1']; + expect(getMessages).toHaveBeenCalled(); + expect(bulkSaveMessages).toHaveBeenCalledWith( + expect.arrayContaining( + expectedMessagesTexts.map((text) => expect.objectContaining({ text })), + ), + ); + }); + + test('should handle errors during message fetching', async () => { + getMessages.mockRejectedValue(new Error('Failed to fetch messages')); + + await expect( + forkConversation({ + originalConvoId: 'abc123', + targetMessageId: '3', + requestUserId: 'user1', + }), + ).rejects.toThrow('Failed to fetch messages'); + }); +}); + +const mockMessagesComplex = [ + { messageId: '7', parentMessageId: Constants.NO_PARENT, text: 'Message 7' }, + { messageId: '8', parentMessageId: Constants.NO_PARENT, text: 'Message 8' }, + { messageId: '5', parentMessageId: '7', text: 'Message 5' }, + { messageId: '6', parentMessageId: '7', text: 'Message 6' }, + { messageId: '9', parentMessageId: '8', text: 'Message 9' }, + { messageId: '2', parentMessageId: '5', text: 'Message 2' }, + { messageId: '3', parentMessageId: '5', text: 'Message 3' }, + { messageId: '1', parentMessageId: '6', text: 'Message 1' }, + { messageId: '4', parentMessageId: '6', text: 'Message 4' }, + { messageId: '10', parentMessageId: '3', text: 'Message 10' }, +]; + +describe('getMessagesUpToTargetLevel', () => { + test('should get all messages up to target level', async () => { + const result = getMessagesUpToTargetLevel(mockMessagesComplex, '5'); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesUpToTargetLevel] should get all messages up to target level\n', + mappedResult, + ); + console.debug('mockMessages\n', printMessageTree(mockMessagesComplex)); + console.debug('result\n', printMessageTree(result)); + expect(mappedResult).toEqual(['7', '8', '5', '6', '9']); + }); + + test('should get all messages if target is deepest level', async () => { + const result = getMessagesUpToTargetLevel(mockMessagesComplex, '10'); + expect(result.length).toEqual(mockMessagesComplex.length); + }); + + test('should return target if only message', async () => { + const result = getMessagesUpToTargetLevel( + [mockMessagesComplex[mockMessagesComplex.length - 1]], + '10', + ); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesUpToTargetLevel] should return target if only message\n', + mappedResult, + ); + console.debug('mockMessages\n', printMessageTree(mockMessages)); + console.debug('result\n', printMessageTree(result)); + expect(mappedResult).toEqual(['10']); + }); + + test('should return empty array if target message ID does not exist', async () => { + const result = getMessagesUpToTargetLevel(mockMessagesComplex, '123'); + expect(result).toEqual([]); + }); + + test('should return correct messages when target is a root message', async () => { + const result = getMessagesUpToTargetLevel(mockMessagesComplex, '7'); + const mappedResult = result.map((msg) => msg.messageId); + expect(mappedResult).toEqual(['7', '8']); + }); + + test('should correctly handle single message with non-matching ID', async () => { + const singleMessage = [ + { messageId: '30', parentMessageId: Constants.NO_PARENT, text: 'Message 30' }, + ]; + const result = getMessagesUpToTargetLevel(singleMessage, '31'); + expect(result).toEqual([]); + }); + + test('should correctly handle case with circular dependencies', async () => { + const circularMessages = [ + { messageId: '40', parentMessageId: '42', text: 'Message 40' }, + { messageId: '41', parentMessageId: '40', text: 'Message 41' }, + { messageId: '42', parentMessageId: '41', text: 'Message 42' }, + ]; + const result = getMessagesUpToTargetLevel(circularMessages, '40'); + const mappedResult = result.map((msg) => msg.messageId); + expect(new Set(mappedResult)).toEqual(new Set(['40', '41', '42'])); + }); + + test('should return all messages when all are interconnected and target is deep in hierarchy', async () => { + const interconnectedMessages = [ + { messageId: '50', parentMessageId: Constants.NO_PARENT, text: 'Root Message' }, + { messageId: '51', parentMessageId: '50', text: 'Child Level 1' }, + { messageId: '52', parentMessageId: '51', text: 'Child Level 2' }, + { messageId: '53', parentMessageId: '52', text: 'Child Level 3' }, + ]; + const result = getMessagesUpToTargetLevel(interconnectedMessages, '53'); + const mappedResult = result.map((msg) => msg.messageId); + expect(mappedResult).toEqual(['50', '51', '52', '53']); + }); +}); + +describe('getAllMessagesUpToParent', () => { + const mockMessages = [ + { messageId: '11', parentMessageId: Constants.NO_PARENT, text: 'Message 11' }, + { messageId: '12', parentMessageId: Constants.NO_PARENT, text: 'Message 12' }, + { messageId: '13', parentMessageId: '11', text: 'Message 13' }, + { messageId: '14', parentMessageId: '12', text: 'Message 14' }, + { messageId: '15', parentMessageId: '13', text: 'Message 15' }, + { messageId: '16', parentMessageId: '13', text: 'Message 16' }, + { messageId: '21', parentMessageId: '13', text: 'Message 21' }, + { messageId: '17', parentMessageId: '14', text: 'Message 17' }, + { messageId: '18', parentMessageId: '16', text: 'Message 18' }, + { messageId: '19', parentMessageId: '18', text: 'Message 19' }, + { messageId: '20', parentMessageId: '19', text: 'Message 20' }, + ]; + + test('should handle empty message list', async () => { + const result = getAllMessagesUpToParent([], '10'); + expect(result).toEqual([]); + }); + + test('should handle target message not found', async () => { + const result = getAllMessagesUpToParent(mockMessages, 'invalid-id'); + expect(result).toEqual([]); + }); + + test('should handle single level tree (no parents)', async () => { + const result = getAllMessagesUpToParent( + [ + { messageId: '11', parentMessageId: Constants.NO_PARENT, text: 'Message 11' }, + { messageId: '12', parentMessageId: Constants.NO_PARENT, text: 'Message 12' }, + ], + '11', + ); + const mappedResult = result.map((msg) => msg.messageId); + expect(mappedResult).toEqual(['11']); + }); + + test('should correctly retrieve messages in a deeply nested structure', async () => { + const result = getAllMessagesUpToParent(mockMessages, '20'); + const mappedResult = result.map((msg) => msg.messageId); + expect(mappedResult).toContain('11'); + expect(mappedResult).toContain('13'); + expect(mappedResult).toContain('16'); + expect(mappedResult).toContain('18'); + expect(mappedResult).toContain('19'); + expect(mappedResult).toContain('20'); + }); + + test('should return only the target message if it has no parent', async () => { + const result = getAllMessagesUpToParent(mockMessages, '11'); + const mappedResult = result.map((msg) => msg.messageId); + expect(mappedResult).toEqual(['11']); + }); + + test('should handle messages without a parent ID defined', async () => { + const additionalMessages = [ + ...mockMessages, + { messageId: '22', text: 'Message 22' }, // No parentMessageId field + ]; + const result = getAllMessagesUpToParent(additionalMessages, '22'); + const mappedResult = result.map((msg) => msg.messageId); + expect(mappedResult).toEqual(['22']); + }); + + test('should retrieve all messages from the target to the root (including indirect ancestors)', async () => { + const result = getAllMessagesUpToParent(mockMessages, '18'); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getAllMessagesUpToParent] should retrieve all messages from the target to the root\n', + mappedResult, + ); + console.debug('mockMessages\n', printMessageTree(mockMessages)); + console.debug('result\n', printMessageTree(result)); + expect(mappedResult).toEqual(['11', '13', '15', '16', '21', '18']); + }); + + test('should handle circular dependencies gracefully', () => { + const mockMessages = [ + { messageId: '1', parentMessageId: '2' }, + { messageId: '2', parentMessageId: '3' }, + { messageId: '3', parentMessageId: '1' }, + ]; + + const targetMessageId = '1'; + const result = getAllMessagesUpToParent(mockMessages, targetMessageId); + + const uniqueIds = new Set(result.map((msg) => msg.messageId)); + expect(uniqueIds.size).toBe(result.length); + expect(result.map((msg) => msg.messageId).sort()).toEqual(['1', '2', '3'].sort()); + }); + + test('should return target if only message', async () => { + const result = getAllMessagesUpToParent([mockMessages[mockMessages.length - 1]], '20'); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getAllMessagesUpToParent] should return target if only message\n', + mappedResult, + ); + console.debug('mockMessages\n', printMessageTree(mockMessages)); + console.debug('result\n', printMessageTree(result)); + expect(mappedResult).toEqual(['20']); + }); +}); + +describe('getMessagesForConversation', () => { + const mockMessages = [ + { messageId: '11', parentMessageId: Constants.NO_PARENT, text: 'Message 11' }, + { messageId: '12', parentMessageId: Constants.NO_PARENT, text: 'Message 12' }, + { messageId: '13', parentMessageId: '11', text: 'Message 13' }, + { messageId: '14', parentMessageId: '12', text: 'Message 14' }, + { messageId: '15', parentMessageId: '13', text: 'Message 15' }, + { messageId: '16', parentMessageId: '13', text: 'Message 16' }, + { messageId: '21', parentMessageId: '13', text: 'Message 21' }, + { messageId: '17', parentMessageId: '14', text: 'Message 17' }, + { messageId: '18', parentMessageId: '16', text: 'Message 18' }, + { messageId: '19', parentMessageId: '18', text: 'Message 19' }, + { messageId: '20', parentMessageId: '19', text: 'Message 20' }, + ]; + + test('should provide the direct path to the target without branches', async () => { + const result = BaseClient.getMessagesForConversation({ + messages: mockMessages, + parentMessageId: '18', + }); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesForConversation] should provide the direct path to the target without branches\n', + mappedResult, + ); + console.debug('mockMessages\n', printMessageTree(mockMessages)); + console.debug('result\n', printMessageTree(result)); + expect(new Set(mappedResult)).toEqual(new Set(['11', '13', '16', '18'])); + }); + + test('should return target if only message', async () => { + const result = BaseClient.getMessagesForConversation({ + messages: [mockMessages[mockMessages.length - 1]], + parentMessageId: '20', + }); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesForConversation] should return target if only message\n', + mappedResult, + ); + console.debug('mockMessages\n', printMessageTree(mockMessages)); + console.debug('result\n', printMessageTree(result)); + expect(new Set(mappedResult)).toEqual(new Set(['20'])); + }); + + test('should break on detecting a circular dependency', async () => { + const mockMessagesWithCycle = [ + ...mockMessagesComplex, + { messageId: '100', parentMessageId: '101', text: 'Message 100' }, + { messageId: '101', parentMessageId: '100', text: 'Message 101' }, // introduces circular dependency + ]; + + const result = BaseClient.getMessagesForConversation({ + messages: mockMessagesWithCycle, + parentMessageId: '100', + }); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesForConversation] should break on detecting a circular dependency\n', + mappedResult, + ); + expect(mappedResult).toEqual(['101', '100']); + }); + + // Testing with mockMessagesComplex + test('should correctly find the conversation path including root messages', async () => { + const result = BaseClient.getMessagesForConversation({ + messages: mockMessagesComplex, + parentMessageId: '2', + }); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesForConversation] should correctly find the conversation path including root messages\n', + mappedResult, + ); + expect(new Set(mappedResult)).toEqual(new Set(['7', '5', '2'])); + }); + + // Testing summary feature + test('should stop at summary if option is enabled', async () => { + const messagesWithSummary = [ + ...mockMessagesComplex, + { messageId: '11', parentMessageId: '7', text: 'Message 11', summary: 'Summary for 11' }, + ]; + + const result = BaseClient.getMessagesForConversation({ + messages: messagesWithSummary, + parentMessageId: '11', + summary: true, + }); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesForConversation] should stop at summary if option is enabled\n', + mappedResult, + ); + expect(mappedResult).toEqual(['11']); // Should include only the summarizing message + }); + + // Testing no parent condition + test('should return only the root message if no parent exists', async () => { + const result = BaseClient.getMessagesForConversation({ + messages: mockMessagesComplex, + parentMessageId: '8', + }); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + '[getMessagesForConversation] should return only the root message if no parent exists\n', + mappedResult, + ); + expect(mappedResult).toEqual(['8']); // The message with no parent in the thread + }); +}); + +describe('splitAtTargetLevel', () => { + /* const mockMessagesComplex = [ + { messageId: '7', parentMessageId: Constants.NO_PARENT, text: 'Message 7' }, + { messageId: '8', parentMessageId: Constants.NO_PARENT, text: 'Message 8' }, + { messageId: '5', parentMessageId: '7', text: 'Message 5' }, + { messageId: '6', parentMessageId: '7', text: 'Message 6' }, + { messageId: '9', parentMessageId: '8', text: 'Message 9' }, + { messageId: '2', parentMessageId: '5', text: 'Message 2' }, + { messageId: '3', parentMessageId: '5', text: 'Message 3' }, + { messageId: '1', parentMessageId: '6', text: 'Message 1' }, + { messageId: '4', parentMessageId: '6', text: 'Message 4' }, + { messageId: '10', parentMessageId: '3', text: 'Message 10' }, + ]; + + mockMessages + ├── [7]: Root + | ├── [5]: Child of 7 + | | ├── [2]: Child of 5 + | | └── [3]: Child of 5 + | | └── [10]: Child of 3 + | └── [6]: Child of 7 + | ├── [1]: Child of 6 + | └── [4]: Child of 6 + └── [8]: Root + └── [9]: Child of 8 + */ + test('should include target message level and all descendants (1/2)', () => { + console.debug('splitAtTargetLevel: mockMessages\n', printMessageTree(mockMessagesComplex)); + const result = splitAtTargetLevel(mockMessagesComplex, '2'); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + 'splitAtTargetLevel: include target message level and all descendants (1/2)\n', + printMessageTree(result), + ); + expect(mappedResult).toEqual(['2', '3', '1', '4', '10']); + }); + + test('should include target message level and all descendants (2/2)', () => { + console.debug('splitAtTargetLevel: mockMessages\n', printMessageTree(mockMessagesComplex)); + const result = splitAtTargetLevel(mockMessagesComplex, '5'); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + 'splitAtTargetLevel: include target message level and all descendants (2/2)\n', + printMessageTree(result), + ); + expect(mappedResult).toEqual(['5', '6', '9', '2', '3', '1', '4', '10']); + }); + + test('should handle when target message is root', () => { + const result = splitAtTargetLevel(mockMessagesComplex, '7'); + console.debug('splitAtTargetLevel: target level is root message\n', printMessageTree(result)); + expect(result.length).toBe(mockMessagesComplex.length); + }); + + test('should handle when target message is deepest, lonely child', () => { + const result = splitAtTargetLevel(mockMessagesComplex, '10'); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + 'splitAtTargetLevel: target message is deepest, lonely child\n', + printMessageTree(result), + ); + expect(mappedResult).toEqual(['10']); + }); + + test('should handle when target level is last with many neighbors', () => { + const mockMessages = [ + ...mockMessagesComplex, + { messageId: '11', parentMessageId: '10', text: 'Message 11' }, + { messageId: '12', parentMessageId: '10', text: 'Message 12' }, + { messageId: '13', parentMessageId: '10', text: 'Message 13' }, + { messageId: '14', parentMessageId: '10', text: 'Message 14' }, + { messageId: '15', parentMessageId: '4', text: 'Message 15' }, + { messageId: '16', parentMessageId: '15', text: 'Message 15' }, + ]; + const result = splitAtTargetLevel(mockMessages, '11'); + const mappedResult = result.map((msg) => msg.messageId); + console.debug( + 'splitAtTargetLevel: should handle when target level is last with many neighbors\n', + printMessageTree(result), + ); + expect(mappedResult).toEqual(['11', '12', '13', '14', '16']); + }); + + test('should handle non-existent target message', () => { + // Non-existent message ID + const result = splitAtTargetLevel(mockMessagesComplex, '99'); + expect(result.length).toBe(0); + }); +}); diff --git a/api/server/utils/import/importBatchBuilder.js b/api/server/utils/import/importBatchBuilder.js index 493559afa36..16b4f0ffdaf 100644 --- a/api/server/utils/import/importBatchBuilder.js +++ b/api/server/utils/import/importBatchBuilder.js @@ -70,10 +70,12 @@ class ImportBatchBuilder { * Finishes the current conversation and adds it to the batch. * @param {string} [title='Imported Chat'] - The title of the conversation. Defaults to 'Imported Chat'. * @param {Date} [createdAt] - The creation date of the conversation. - * @returns {object} The added conversation object. + * @param {TConversation} [originalConvo] - The original conversation. + * @returns {{ conversation: TConversation, messages: TMessage[] }} The resulting conversation and messages. */ - finishConversation(title, createdAt) { + finishConversation(title, createdAt, originalConvo = {}) { const convo = { + ...originalConvo, user: this.requestUserId, conversationId: this.conversationId, title: title || 'Imported Chat', @@ -81,11 +83,12 @@ class ImportBatchBuilder { updatedAt: createdAt, overrideTimestamp: true, endpoint: this.endpoint, - model: openAISettings.model.default, + model: originalConvo.model ?? openAISettings.model.default, }; + convo._id && delete convo._id; this.conversations.push(convo); - return convo; + return { conversation: convo, messages: this.messages }; } /** @@ -114,7 +117,9 @@ class ImportBatchBuilder { * @param {string} [messageDetails.messageId] - The ID of the current message. * @param {boolean} messageDetails.isCreatedByUser - Indicates whether the message is created by the user. * @param {string} [messageDetails.model] - The model used for generating the message. + * @param {string} [messageDetails.endpoint] - The endpoint used for generating the message. * @param {string} [messageDetails.parentMessageId=this.lastMessageId] - The ID of the parent message. + * @param {Partial} messageDetails.rest - Additional properties that may be included in the message. * @returns {object} The saved message object. */ saveMessage({ @@ -124,22 +129,26 @@ class ImportBatchBuilder { model, messageId, parentMessageId = this.lastMessageId, + endpoint, + ...rest }) { const newMessageId = messageId ?? uuidv4(); const message = { + ...rest, parentMessageId, messageId: newMessageId, conversationId: this.conversationId, isCreatedByUser: isCreatedByUser, model: model || this.model, user: this.requestUserId, - endpoint: this.endpoint, + endpoint: endpoint ?? this.endpoint, unfinished: false, isEdited: false, error: false, sender, text, }; + message._id && delete message._id; this.lastMessageId = newMessageId; this.messages.push(message); return message; diff --git a/api/server/utils/import/importers.js b/api/server/utils/import/importers.js index cd901d8e8bc..f1762a988eb 100644 --- a/api/server/utils/import/importers.js +++ b/api/server/utils/import/importers.js @@ -48,7 +48,7 @@ async function importChatBotUiConvo( ) { // this have been tested with chatbot-ui V1 export https://github.com/mckaywrigley/chatbot-ui/tree/b865b0555f53957e96727bc0bbb369c9eaecd83b#legacy-code try { - /** @type {import('./importBatchBuilder').ImportBatchBuilder} */ + /** @type {ImportBatchBuilder} */ const importBatchBuilder = builderFactory(requestUserId); for (const historyItem of jsonData.history) { @@ -83,7 +83,7 @@ async function importLibreChatConvo( builderFactory = createImportBatchBuilder, ) { try { - /** @type {import('./importBatchBuilder').ImportBatchBuilder} */ + /** @type {ImportBatchBuilder} */ const importBatchBuilder = builderFactory(requestUserId); importBatchBuilder.startConversation(EModelEndpoint.openAI); @@ -163,7 +163,7 @@ async function importChatGptConvo( * It directly manages the addition of messages for different roles and handles citations for assistant messages. * * @param {ChatGPTConvo} conv - A single conversation object that contains multiple messages and other details. - * @param {import('./importBatchBuilder').ImportBatchBuilder} importBatchBuilder - The batch builder instance used to manage and batch conversation data. + * @param {ImportBatchBuilder} importBatchBuilder - The batch builder instance used to manage and batch conversation data. * @param {string} requestUserId - The ID of the user who initiated the import process. * @returns {void} */ diff --git a/api/typedefs.js b/api/typedefs.js index 44dcf784f14..df5e8be2bef 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -644,6 +644,12 @@ * @memberof typedefs */ +/** + * @exports ImportBatchBuilder + * @typedef {import('./server/utils/import/importBatchBuilder.js').ImportBatchBuilder} ImportBatchBuilder + * @memberof typedefs + */ + /** * @exports Thread * @typedef {Object} Thread @@ -1257,3 +1263,17 @@ * @property {Object.} mapping - Mapping of message nodes within the conversation. * @memberof typedefs */ + +/** Mutations */ + +/** + * @exports TForkConvoResponse + * @typedef {import('librechat-data-provider').TForkConvoResponse} TForkConvoResponse + * @memberof typedefs + */ + +/** + * @exports TForkConvoRequest + * @typedef {import('librechat-data-provider').TForkConvoRequest} TForkConvoRequest + * @memberof typedefs + */ diff --git a/client/src/components/Chat/Input/GenerationButtons.tsx b/client/src/components/Chat/Input/GenerationButtons.tsx deleted file mode 100644 index d3e74c84a6a..00000000000 --- a/client/src/components/Chat/Input/GenerationButtons.tsx +++ /dev/null @@ -1,86 +0,0 @@ -import { useEffect, useState } from 'react'; -import type { TMessage } from 'librechat-data-provider'; -import { useMediaQuery, useGenerationsByLatest } from '~/hooks'; -import Regenerate from '~/components/Input/Generations/Regenerate'; -import Continue from '~/components/Input/Generations/Continue'; -import Stop from '~/components/Input/Generations/Stop'; -import { useChatContext } from '~/Providers'; -import { cn } from '~/utils'; - -type GenerationButtonsProps = { - endpoint: string; - showPopover?: boolean; - opacityClass?: string; -}; - -export default function GenerationButtons({ - endpoint, - showPopover = false, - opacityClass = 'full-opacity', -}: GenerationButtonsProps) { - const { - getMessages, - isSubmitting, - latestMessage, - handleContinue, - handleRegenerate, - handleStopGenerating, - } = useChatContext(); - const isSmallScreen = useMediaQuery('(max-width: 768px)'); - const { continueSupported, regenerateEnabled } = useGenerationsByLatest({ - endpoint, - message: latestMessage as TMessage, - isSubmitting, - latestMessage, - }); - - const [userStopped, setUserStopped] = useState(false); - const messages = getMessages(); - - const handleStop = (e: React.MouseEvent) => { - setUserStopped(true); - handleStopGenerating(e); - }; - - useEffect(() => { - let timer: NodeJS.Timeout; - - if (userStopped) { - timer = setTimeout(() => { - setUserStopped(false); - }, 200); - } - - return () => { - clearTimeout(timer); - }; - }, [userStopped]); - - if (isSmallScreen) { - return null; - } - - let button: React.ReactNode = null; - - if (isSubmitting) { - button = ; - } else if (userStopped || continueSupported) { - button = ; - } else if (messages && messages.length > 0 && regenerateEnabled) { - button = ; - } - - return ( -
-
-
-
- {button} -
-
-
- ); -} diff --git a/client/src/components/Chat/Messages/HoverButtons.tsx b/client/src/components/Chat/Messages/HoverButtons.tsx index 7f6b288cbbe..63da4ecaa0e 100644 --- a/client/src/components/Chat/Messages/HoverButtons.tsx +++ b/client/src/components/Chat/Messages/HoverButtons.tsx @@ -3,6 +3,7 @@ import { EModelEndpoint } from 'librechat-data-provider'; import type { TConversation, TMessage } from 'librechat-data-provider'; import { Clipboard, CheckMark, EditIcon, RegenerateIcon, ContinueIcon } from '~/components/svg'; import { useGenerationsByLatest, useLocalize } from '~/hooks'; +import { Fork } from '~/components/Conversations'; import { cn } from '~/utils'; type THoverButtons = { @@ -34,13 +35,14 @@ export default function HoverButtons({ const { endpoint: _endpoint, endpointType } = conversation ?? {}; const endpoint = endpointType ?? _endpoint; const [isCopied, setIsCopied] = useState(false); - const { hideEditButton, regenerateEnabled, continueSupported } = useGenerationsByLatest({ - isEditing, - isSubmitting, - message, - endpoint: endpoint ?? '', - latestMessage, - }); + const { hideEditButton, regenerateEnabled, continueSupported, forkingSupported } = + useGenerationsByLatest({ + isEditing, + isSubmitting, + message, + endpoint: endpoint ?? '', + latestMessage, + }); if (!conversation) { return null; } @@ -113,6 +115,13 @@ export default function HoverButtons({ ) : null} +
); } diff --git a/client/src/components/Conversations/Fork.tsx b/client/src/components/Conversations/Fork.tsx new file mode 100644 index 00000000000..d4e53e86e69 --- /dev/null +++ b/client/src/components/Conversations/Fork.tsx @@ -0,0 +1,331 @@ +import { useState, useRef } from 'react'; +import { useRecoilState } from 'recoil'; +import { GitFork, InfoIcon } from 'lucide-react'; +import * as Popover from '@radix-ui/react-popover'; +import { ForkOptions, TMessage } from 'librechat-data-provider'; +import { GitCommit, GitBranchPlus, ListTree } from 'lucide-react'; +import { + Checkbox, + HoverCard, + HoverCardTrigger, + HoverCardPortal, + HoverCardContent, +} from '~/components/ui'; +import OptionHover from '~/components/SidePanel/Parameters/OptionHover'; +import { useToastContext, useChatContext } from '~/Providers'; +import { useLocalize, useNavigateToConvo } from '~/hooks'; +import { useForkConvoMutation } from '~/data-provider'; +import { ESide } from '~/common'; +import { cn } from '~/utils'; +import store from '~/store'; + +interface PopoverButtonProps { + children: React.ReactNode; + setting: string; + onClick: (setting: string) => void; + setActiveSetting: React.Dispatch>; + sideOffset?: number; + timeoutRef: React.MutableRefObject; + hoverInfo?: React.ReactNode; + hoverTitle?: React.ReactNode; + hoverDescription?: React.ReactNode; +} + +const optionLabels = { + [ForkOptions.DIRECT_PATH]: 'com_ui_fork_visible', + [ForkOptions.INCLUDE_BRANCHES]: 'com_ui_fork_branches', + [ForkOptions.TARGET_LEVEL]: 'com_ui_fork_all_target', + default: 'com_ui_fork_from_message', +}; + +const PopoverButton: React.FC = ({ + children, + setting, + onClick, + setActiveSetting, + sideOffset = 30, + timeoutRef, + hoverInfo, + hoverTitle, + hoverDescription, +}) => { + return ( + + onClick(setting)} + onMouseEnter={() => { + if (timeoutRef.current) { + clearTimeout(timeoutRef.current); + timeoutRef.current = null; + } + setActiveSetting(optionLabels[setting]); + }} + onMouseLeave={() => { + if (timeoutRef.current) { + clearTimeout(timeoutRef.current); + } + timeoutRef.current = setTimeout(() => { + setActiveSetting(optionLabels.default); + }, 175); + }} + className="mx-1 max-w-14 flex-1 rounded-lg border-2 bg-white transition duration-300 ease-in-out hover:bg-black dark:border-gray-400 dark:bg-gray-700/95 dark:text-gray-400 hover:dark:border-gray-200 hover:dark:text-gray-200" + type="button" + > + {children} + + {(hoverInfo || hoverTitle || hoverDescription) && ( + + +
+

+ {hoverInfo && hoverInfo} + {hoverTitle && {hoverTitle}} + {hoverDescription && hoverDescription} +

+
+
+
+ )} +
+ ); +}; + +export default function Fork({ + isLast, + messageId, + conversationId, + forkingSupported, + latestMessage, +}: { + isLast?: boolean; + messageId: string; + conversationId: string | null; + forkingSupported?: boolean; + latestMessage: TMessage | null; +}) { + const localize = useLocalize(); + const { index } = useChatContext(); + const { showToast } = useToastContext(); + const [remember, setRemember] = useState(false); + const { navigateToConvo } = useNavigateToConvo(index); + const timeoutRef = useRef(null); + const [forkSetting, setForkSetting] = useRecoilState(store.forkSetting); + const [activeSetting, setActiveSetting] = useState(optionLabels.default); + const [splitAtTarget, setSplitAtTarget] = useRecoilState(store.splitAtTarget); + const [rememberGlobal, setRememberGlobal] = useRecoilState(store.rememberForkOption); + const forkConvo = useForkConvoMutation({ + onSuccess: (data) => { + if (data) { + navigateToConvo(data.conversation); + showToast({ + message: localize('com_ui_fork_success'), + status: 'success', + }); + } + }, + onMutate: () => { + showToast({ + message: localize('com_ui_fork_processing'), + status: 'info', + }); + }, + onError: () => { + showToast({ + message: localize('com_ui_fork_error'), + status: 'error', + }); + }, + }); + + if (!forkingSupported || !conversationId || !messageId) { + return null; + } + + const onClick = (option: string) => { + if (remember) { + setRememberGlobal(true); + setForkSetting(option); + } + + forkConvo.mutate({ + messageId, + conversationId, + option, + splitAtTarget, + latestMessageId: latestMessage?.messageId, + }); + }; + + return ( + + + + + +
+ +
+ {localize(activeSetting)} + + + + + + +
+ {localize('com_ui_fork_info_1')} + {localize('com_ui_fork_info_2')} + + {localize('com_ui_fork_info_3', localize('com_ui_fork_split_target'))} + +
+
+
+
+
+
+ + + {localize(optionLabels[ForkOptions.DIRECT_PATH])} + + } + hoverDescription={localize('com_ui_fork_info_visible')} + > + + + + + + + {localize(optionLabels[ForkOptions.INCLUDE_BRANCHES])} + + } + hoverDescription={localize('com_ui_fork_info_branches')} + > + + + + + + + {`${localize(optionLabels[ForkOptions.TARGET_LEVEL])} (${localize( + 'com_endpoint_default', + )})`} + + } + hoverDescription={localize('com_ui_fork_info_target')} + > + + + + +
+ + +
+ setSplitAtTarget(checked)} + className="m-2 transition duration-300 ease-in-out" + /> + {localize('com_ui_fork_split_target')} +
+
+ +
+ + +
+ { + if (checked) { + showToast({ + message: localize('com_ui_fork_remember_checked'), + status: 'info', + }); + } + setRemember(checked); + }} + className="m-2 transition duration-300 ease-in-out" + /> + {localize('com_ui_fork_remember')} +
+
+ +
+
+
+
+
+ ); +} diff --git a/client/src/components/Conversations/index.ts b/client/src/components/Conversations/index.ts index 17acb3e438c..72e8babd449 100644 --- a/client/src/components/Conversations/index.ts +++ b/client/src/components/Conversations/index.ts @@ -1,3 +1,4 @@ +export { default as Fork } from './Fork'; export { default as Pages } from './Pages'; export { default as Conversation } from './Conversation'; export { default as RenameButton } from './RenameButton'; diff --git a/client/src/components/Input/GenerationButtons.tsx b/client/src/components/Input/GenerationButtons.tsx deleted file mode 100644 index 71479febe37..00000000000 --- a/client/src/components/Input/GenerationButtons.tsx +++ /dev/null @@ -1,48 +0,0 @@ -// eslint-disable-next-line @typescript-eslint/no-unused-vars -import { cn, removeFocusOutlines } from '~/utils/'; - -type GenerationButtonsProps = { - showPopover: boolean; - opacityClass: string; -}; - -export default function GenerationButtons({ showPopover, opacityClass }: GenerationButtonsProps) { - return ( -
-
-
-
- {/* */} -
-
-
- ); -} diff --git a/client/src/components/Messages/HoverButtons.tsx b/client/src/components/Messages/HoverButtons.tsx deleted file mode 100644 index 6841b471a79..00000000000 --- a/client/src/components/Messages/HoverButtons.tsx +++ /dev/null @@ -1,101 +0,0 @@ -import { useState } from 'react'; -import type { TConversation, TMessage } from 'librechat-data-provider'; -import { Clipboard, CheckMark, EditIcon, RegenerateIcon, ContinueIcon } from '~/components/svg'; -import { useGenerations, useLocalize } from '~/hooks'; -import { cn } from '~/utils'; - -type THoverButtons = { - isEditing: boolean; - enterEdit: (cancel?: boolean) => void; - copyToClipboard: (setIsCopied: React.Dispatch>) => void; - conversation: TConversation | null; - isSubmitting: boolean; - message: TMessage; - regenerate: () => void; - handleContinue: (e: React.MouseEvent) => void; -}; - -export default function HoverButtons({ - isEditing, - enterEdit, - copyToClipboard, - conversation, - isSubmitting, - message, - regenerate, - handleContinue, -}: THoverButtons) { - const localize = useLocalize(); - const { endpoint } = conversation ?? {}; - const [isCopied, setIsCopied] = useState(false); - const { hideEditButton, regenerateEnabled, continueSupported } = useGenerations({ - isEditing, - isSubmitting, - message, - endpoint: endpoint ?? '', - }); - if (!conversation) { - return null; - } - - const { isCreatedByUser } = message; - - const onEdit = () => { - if (isEditing) { - return enterEdit(true); - } - enterEdit(); - }; - - return ( -
- - - {regenerateEnabled ? ( - - ) : null} - {continueSupported ? ( - - ) : null} -
- ); -} diff --git a/client/src/components/Nav/NavLinks.tsx b/client/src/components/Nav/NavLinks.tsx index ffd7c8bb3bd..964c9193146 100644 --- a/client/src/components/Nav/NavLinks.tsx +++ b/client/src/components/Nav/NavLinks.tsx @@ -79,11 +79,11 @@ function NavLinks() {
diff --git a/client/src/components/Nav/Settings.tsx b/client/src/components/Nav/Settings.tsx index 83ec09fb224..f83810667db 100644 --- a/client/src/components/Nav/Settings.tsx +++ b/client/src/components/Nav/Settings.tsx @@ -1,9 +1,10 @@ import * as Tabs from '@radix-ui/react-tabs'; +import { MessageSquare } from 'lucide-react'; import { SettingsTabValues } from 'librechat-data-provider'; import type { TDialogProps } from '~/common'; import { Dialog, DialogContent, DialogHeader, DialogTitle } from '~/components/ui'; import { GearIcon, DataIcon, UserIcon, ExperimentIcon } from '~/components/svg'; -import { General, Beta, Data, Account } from './SettingsTabs'; +import { General, Messages, Beta, Data, Account } from './SettingsTabs'; import { useMediaQuery, useLocalize } from '~/hooks'; import { cn } from '~/utils'; @@ -54,6 +55,20 @@ export default function Settings({ open, onOpenChange }: TDialogProps) { {localize('com_nav_setting_general')} + + + {localize('com_endpoint_messages')} + + diff --git a/client/src/components/Nav/SettingsTabs/DangerButton.tsx b/client/src/components/Nav/SettingsTabs/DangerButton.tsx index dcec128c95b..d43fd9a5862 100644 --- a/client/src/components/Nav/SettingsTabs/DangerButton.tsx +++ b/client/src/components/Nav/SettingsTabs/DangerButton.tsx @@ -40,7 +40,7 @@ const DangerButton = (props: TDangerButtonProps, ref: ForwardedRef diff --git a/client/src/components/Nav/SettingsTabs/Data/ClearChats.tsx b/client/src/components/Nav/SettingsTabs/Data/ClearChats.tsx new file mode 100644 index 00000000000..4fc2dadd8e0 --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Data/ClearChats.tsx @@ -0,0 +1,29 @@ +import type { TDangerButtonProps } from '~/common'; +import DangerButton from '../DangerButton'; + +export const ClearChatsButton = ({ + confirmClear, + className = '', + showText = true, + mutation, + onClick, +}: Pick< + TDangerButtonProps, + 'confirmClear' | 'mutation' | 'className' | 'showText' | 'onClick' +>) => { + return ( + + ); +}; diff --git a/client/src/components/Nav/SettingsTabs/General/ClearChatsButton.spec.tsx b/client/src/components/Nav/SettingsTabs/Data/ClearChatsButton.spec.tsx similarity index 96% rename from client/src/components/Nav/SettingsTabs/General/ClearChatsButton.spec.tsx rename to client/src/components/Nav/SettingsTabs/Data/ClearChatsButton.spec.tsx index 14a2edfd918..fbf82fd9020 100644 --- a/client/src/components/Nav/SettingsTabs/General/ClearChatsButton.spec.tsx +++ b/client/src/components/Nav/SettingsTabs/Data/ClearChatsButton.spec.tsx @@ -2,7 +2,7 @@ import 'test/matchMedia.mock'; import React from 'react'; import { render, fireEvent } from '@testing-library/react'; import '@testing-library/jest-dom/extend-expect'; -import { ClearChatsButton } from './General'; +import { ClearChatsButton } from './ClearChats'; import { RecoilRoot } from 'recoil'; describe('ClearChatsButton', () => { diff --git a/client/src/components/Nav/SettingsTabs/Data/Data.tsx b/client/src/components/Nav/SettingsTabs/Data/Data.tsx index f41642e634e..69e4e4887bc 100644 --- a/client/src/components/Nav/SettingsTabs/Data/Data.tsx +++ b/client/src/components/Nav/SettingsTabs/Data/Data.tsx @@ -1,13 +1,15 @@ import * as Tabs from '@radix-ui/react-tabs'; import { - useRevokeAllUserKeysMutation, useRevokeUserKeyMutation, + useRevokeAllUserKeysMutation, + useClearConversationsMutation, } from 'librechat-data-provider/react-query'; import { SettingsTabValues } from 'librechat-data-provider'; import React, { useState, useCallback, useRef } from 'react'; -import { useOnClickOutside } from '~/hooks'; -import DangerButton from '../DangerButton'; +import { useConversation, useConversations, useOnClickOutside } from '~/hooks'; import ImportConversations from './ImportConversations'; +import { ClearChatsButton } from './ClearChats'; +import DangerButton from '../DangerButton'; export const RevokeKeysButton = ({ showText = true, @@ -20,42 +22,43 @@ export const RevokeKeysButton = ({ all?: boolean; disabled?: boolean; }) => { - const [confirmClear, setConfirmClear] = useState(false); - const revokeKeyMutation = useRevokeUserKeyMutation(endpoint); + const [confirmRevoke, setConfirmRevoke] = useState(false); + const revokeKeysMutation = useRevokeAllUserKeysMutation(); + const revokeKeyMutation = useRevokeUserKeyMutation(endpoint); - const contentRef = useRef(null); - useOnClickOutside(contentRef, () => confirmClear && setConfirmClear(false), []); + const revokeContentRef = useRef(null); + useOnClickOutside(revokeContentRef, () => confirmRevoke && setConfirmRevoke(false), []); const revokeAllUserKeys = useCallback(() => { - if (confirmClear) { + if (confirmRevoke) { revokeKeysMutation.mutate({}); - setConfirmClear(false); + setConfirmRevoke(false); } else { - setConfirmClear(true); + setConfirmRevoke(true); } - }, [confirmClear, revokeKeysMutation]); + }, [confirmRevoke, revokeKeysMutation]); const revokeUserKey = useCallback(() => { if (!endpoint) { return; - } else if (confirmClear) { + } else if (confirmRevoke) { revokeKeyMutation.mutate({}); - setConfirmClear(false); + setConfirmRevoke(false); } else { - setConfirmClear(true); + setConfirmRevoke(true); } - }, [confirmClear, revokeKeyMutation, endpoint]); + }, [confirmRevoke, revokeKeyMutation, endpoint]); const onClick = all ? revokeAllUserKeys : revokeUserKey; return ( confirmClearConvos && setConfirmClearConvos(false), []); + + const { newConversation } = useConversation(); + const { refreshConversations } = useConversations(); + const clearConvosMutation = useClearConversationsMutation(); + + const clearConvos = () => { + if (confirmClearConvos) { + console.log('Clearing conversations...'); + setConfirmClearConvos(false); + clearConvosMutation.mutate( + {}, + { + onSuccess: () => { + newConversation(); + refreshConversations(); + }, + }, + ); + } else { + setConfirmClearConvos(true); + } + }; + return (
+
+ +
+
- +
diff --git a/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx b/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx index af6e21aa7f9..ae781dfd810 100644 --- a/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx +++ b/client/src/components/Nav/SettingsTabs/Data/ImportConversations.tsx @@ -1,15 +1,17 @@ +import { useState } from 'react'; import { Import } from 'lucide-react'; -import { cn } from '~/utils'; import { useUploadConversationsMutation } from '~/data-provider'; import { useLocalize, useConversations } from '~/hooks'; -import { useState } from 'react'; import { useToastContext } from '~/Providers'; +import { Spinner } from '~/components/svg'; +import { cn } from '~/utils'; function ImportConversations() { const localize = useLocalize(); const { showToast } = useToastContext(); const [, setErrors] = useState([]); + const [allowImport, setAllowImport] = useState(true); const setError = (error: string) => setErrors((prevErrors) => [...prevErrors, error]); const { refreshConversations } = useConversations(); @@ -17,9 +19,11 @@ function ImportConversations() { onSuccess: () => { refreshConversations(); showToast({ message: localize('com_ui_import_conversation_success') }); + setAllowImport(true); }, onError: (error) => { console.error('Error: ', error); + setAllowImport(true); setError( (error as { response: { data: { message?: string } } })?.response?.data?.message ?? 'An error occurred while uploading the file.', @@ -33,6 +37,9 @@ function ImportConversations() { showToast({ message: localize('com_ui_import_conversation_error'), status: 'error' }); } }, + onMutate: () => { + setAllowImport(false); + }, }); const startUpload = async (file: File) => { @@ -43,8 +50,6 @@ function ImportConversations() { }; const handleFiles = async (_file: File) => { - console.log('Handling files...'); - /* Process files */ try { await startUpload(_file); @@ -55,7 +60,6 @@ function ImportConversations() { }; const handleFileChange = (event) => { - console.log('file change'); const file = event.target.files[0]; if (file) { handleFiles(file); @@ -67,12 +71,17 @@ function ImportConversations() { {localize('com_ui_import_conversation_info')}