Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🌿 feat: Fork Messages/Conversations #2617

Merged
merged 35 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b06be8e
typedef for ImportBatchBuilder
danny-avila May 2, 2024
9cd9caa
feat: first pass, fork conversations
danny-avila May 2, 2024
c40b53e
feat: fork - getMessagesUpToTargetLevel
danny-avila May 2, 2024
8e71f98
fix: additional tests and fix getAllMessagesUpToParent
danny-avila May 2, 2024
6bbb9c7
chore: arrow function return
danny-avila May 2, 2024
523d321
refactor: fork 3 options
danny-avila May 2, 2024
eccc8c5
chore: remove unused genbuttons
danny-avila May 2, 2024
00f315a
chore: remove unused hover buttons code
danny-avila May 2, 2024
0f6a928
feat: fork first pass
danny-avila May 3, 2024
106f94c
wip: fork remember setting
danny-avila May 3, 2024
ea23462
style: user icon
danny-avila May 3, 2024
111ba6c
chore: move clear chats to data tab
danny-avila May 3, 2024
0631725
WIP: fork UI options
danny-avila May 3, 2024
9c966f9
feat: data-provider fork types/services/vars and use generic Mutation…
danny-avila May 4, 2024
41eb736
refactor: use single param for fork option, use enum, fix mongo error…
danny-avila May 4, 2024
73fc893
feat: add fork mutation hook and consolidate type imports
danny-avila May 4, 2024
e2525de
refactor: use enum
danny-avila May 4, 2024
bbb03c1
feat: first pass, fork mutation
danny-avila May 4, 2024
e85fc2f
chore: add enum for target level fork option
danny-avila May 4, 2024
e3c4cb2
chore: add enum for target level fork option
danny-avila May 4, 2024
b91d5f9
show toast when checking remember selection
danny-avila May 4, 2024
ca1b477
feat: splitAtTarget
danny-avila May 4, 2024
c6af0e9
feat: split at target option
danny-avila May 4, 2024
df416cc
feat: navigate to new fork, show toasts, set result query data
danny-avila May 4, 2024
b4aa1cb
feat: hover info for all fork options
danny-avila May 5, 2024
eb050b1
refactor: add Messages settings tab
danny-avila May 5, 2024
8f834c1
fix(Fork): remember text info
danny-avila May 5, 2024
ff90cde
ci: test for single message and is target edge case
danny-avila May 5, 2024
deafee9
feat: additional tests for getAllMessagesUpToParent
danny-avila May 5, 2024
8a794a0
ci: additional tests and cycle detection for getMessagesUpToTargetLevel
danny-avila May 5, 2024
c3431f6
feat: circular dependency checks for getAllMessagesUpToParent
danny-avila May 5, 2024
0e7dcb6
fix: getMessagesUpToTargetLevel circular dep. check
danny-avila May 5, 2024
0a5df54
ci: more tests for getMessagesForConversation
danny-avila May 5, 2024
2745e66
style: hover text for checkbox fork items
danny-avila May 5, 2024
ae20082
refactor: add statefulness to conversation import
danny-avila May 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,11 @@ class BaseClient {
* the message is considered a root message.
*
* @param {Object} options - The options for the function.
* @param {Array} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property.
* @param {TMessage[]} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property.
* @param {string} options.parentMessageId - The ID of the parent message to start the traversal from.
* @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. If provided, it will be applied to each message in the resulting array.
* @param {boolean} [options.summary=false] - If set to true, the traversal modifies messages with 'summary' and 'summaryTokenCount' properties and stops at the message with a 'summary' property.
* @returns {Array} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'.
* @returns {TMessage[]} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'.
*/
static getMessagesForConversation({
messages,
Expand Down
6 changes: 6 additions & 0 deletions api/models/Conversation.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ const Conversation = require('./schema/convoSchema');
const { getMessages, deleteMessages } = require('./Message');
const logger = require('~/config/winston');

/**
* Retrieves a single conversation for a given user and conversation ID.
* @param {string} user - The user's ID.
* @param {string} conversationId - The conversation's ID.
* @returns {Promise<TConversation>} The conversation object.
*/
const getConvo = async (user, conversationId) => {
try {
return await Conversation.findOne({ user, conversationId }).lean();
Expand Down
30 changes: 30 additions & 0 deletions api/server/routes/convos.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models
const { IMPORT_CONVERSATION_JOB_NAME } = require('~/server/utils/import/jobDefinition');
const { storage, importFileFilter } = require('~/server/routes/files/multer');
const requireJwtAuth = require('~/server/middleware/requireJwtAuth');
const { forkConversation } = require('~/server/utils/import/fork');
const { createImportLimiters } = require('~/server/middleware');
const jobScheduler = require('~/server/utils/jobScheduler');
const getLogStores = require('~/cache/getLogStores');
Expand Down Expand Up @@ -131,6 +132,35 @@ router.post(
},
);

/**
* POST /fork
* This route handles forking a conversation based on the TForkConvoRequest and responds with TForkConvoResponse.
* @route POST /fork
* @param {express.Request<{}, TForkConvoResponse, TForkConvoRequest>} req - Express request object.
* @param {express.Response<TForkConvoResponse>} res - Express response object.
* @returns {Promise<void>} - The response after forking the conversation.
*/
router.post('/fork', async (req, res) => {
try {
/** @type {TForkConvoRequest} */
const { conversationId, messageId, option, splitAtTarget, latestMessageId } = req.body;
const result = await forkConversation({
requestUserId: req.user.id,
originalConvoId: conversationId,
targetMessageId: messageId,
latestMessageId,
records: true,
splitAtTarget,
option,
});

res.json(result);
} catch (error) {
logger.error('Error forking conversation', error);
res.status(500).send('Error forking conversation');
}
});

// Get the status of an import job for polling
router.get('/import/jobs/:jobId', async (req, res) => {
try {
Expand Down
314 changes: 314 additions & 0 deletions api/server/utils/import/fork.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
const { v4: uuidv4 } = require('uuid');
const { EModelEndpoint, Constants, ForkOptions } = require('librechat-data-provider');
const { createImportBatchBuilder } = require('./importBatchBuilder');
const BaseClient = require('~/app/clients/BaseClient');
const { getConvo } = require('~/models/Conversation');
const { getMessages } = require('~/models/Message');
const logger = require('~/config/winston');

/**
*
* @param {object} params - The parameters for the importer.
* @param {string} params.originalConvoId - The ID of the conversation to fork.
* @param {string} params.targetMessageId - The ID of the message to fork from.
* @param {string} params.requestUserId - The ID of the user making the request.
* @param {string} [params.newTitle] - Optional new title for the forked conversation uses old title if not provided
* @param {string} [params.option=''] - Optional flag for fork option
* @param {boolean} [params.records=false] - Optional flag for returning actual database records or resulting conversation and messages.
* @param {boolean} [params.splitAtTarget=false] - Optional flag for splitting the messages at the target message level.
* @param {string} [params.latestMessageId] - latestMessageId - Required if splitAtTarget is true.
* @param {(userId: string) => ImportBatchBuilder} [params.builderFactory] - Optional factory function for creating an ImportBatchBuilder instance.
* @returns {Promise<TForkConvoResponse>} The response after forking the conversation.
*/
async function forkConversation({
originalConvoId,
targetMessageId: targetId,
requestUserId,
newTitle,
option = ForkOptions.TARGET_LEVEL,
records = false,
splitAtTarget = false,
latestMessageId,
builderFactory = createImportBatchBuilder,
}) {
try {
const originalConvo = await getConvo(requestUserId, originalConvoId);
let originalMessages = await getMessages({
user: requestUserId,
conversationId: originalConvoId,
});

let targetMessageId = targetId;
if (splitAtTarget && !latestMessageId) {
throw new Error('Latest `messageId` is required for forking from target message.');
} else if (splitAtTarget) {
originalMessages = splitAtTargetLevel(originalMessages, targetId);
targetMessageId = latestMessageId;
}

const importBatchBuilder = builderFactory(requestUserId);
importBatchBuilder.startConversation(originalConvo.endpoint ?? EModelEndpoint.openAI);

let messagesToClone = [];

if (option === ForkOptions.DIRECT_PATH) {
// Direct path only
messagesToClone = BaseClient.getMessagesForConversation({
messages: originalMessages,
parentMessageId: targetMessageId,
});
} else if (option === ForkOptions.INCLUDE_BRANCHES) {
// Direct path and siblings
messagesToClone = getAllMessagesUpToParent(originalMessages, targetMessageId);
} else if (option === ForkOptions.TARGET_LEVEL || !option) {
// Direct path, siblings, and all descendants
messagesToClone = getMessagesUpToTargetLevel(originalMessages, targetMessageId);
}

const idMapping = new Map();

for (const message of messagesToClone) {
const newMessageId = uuidv4();
idMapping.set(message.messageId, newMessageId);

const clonedMessage = {
...message,
messageId: newMessageId,
parentMessageId:
message.parentMessageId && message.parentMessageId !== Constants.NO_PARENT
? idMapping.get(message.parentMessageId)
: Constants.NO_PARENT,
};

importBatchBuilder.saveMessage(clonedMessage);
}

const result = importBatchBuilder.finishConversation(
newTitle || originalConvo.title,
new Date(),
originalConvo,
);
await importBatchBuilder.saveBatch();
logger.debug(
`user: ${requestUserId} | New conversation "${
newTitle || originalConvo.title
}" forked from conversation ID ${originalConvoId}`,
);

if (!records) {
return result;
}

const conversation = await getConvo(requestUserId, result.conversation.conversationId);
const messages = await getMessages({
user: requestUserId,
conversationId: conversation.conversationId,
});

return {
conversation,
messages,
};
} catch (error) {
logger.error(
`user: ${requestUserId} | Error forking conversation from original ID ${originalConvoId}`,
error,
);
throw error;
}
}

/**
* Retrieves all messages up to the root from the target message.
* @param {TMessage[]} messages - The list of messages to search.
* @param {string} targetMessageId - The ID of the target message.
* @returns {TMessage[]} The list of messages up to the root from the target message.
*/
function getAllMessagesUpToParent(messages, targetMessageId) {
const targetMessage = messages.find((msg) => msg.messageId === targetMessageId);
if (!targetMessage) {
return [];
}

const pathToRoot = new Set();
const visited = new Set();
let current = targetMessage;

while (current) {
if (visited.has(current.messageId)) {
break;
}

visited.add(current.messageId);
pathToRoot.add(current.messageId);

const currentParentId = current.parentMessageId ?? Constants.NO_PARENT;
if (currentParentId === Constants.NO_PARENT) {
break;
}

current = messages.find((msg) => msg.messageId === currentParentId);
}

// Include all messages that are in the path or whose parent is in the path
// Exclude children of the target message
return messages.filter(
(msg) =>
(pathToRoot.has(msg.messageId) && msg.messageId !== targetMessageId) ||
(pathToRoot.has(msg.parentMessageId) && msg.parentMessageId !== targetMessageId) ||
msg.messageId === targetMessageId,
);
}

/**
* Retrieves all messages up to the root from the target message and its neighbors.
* @param {TMessage[]} messages - The list of messages to search.
* @param {string} targetMessageId - The ID of the target message.
* @returns {TMessage[]} The list of inclusive messages up to the root from the target message.
*/
function getMessagesUpToTargetLevel(messages, targetMessageId) {
if (messages.length === 1 && messages[0] && messages[0].messageId === targetMessageId) {
return messages;
}

// Create a map of parentMessageId to children messages
const parentToChildrenMap = new Map();
for (const message of messages) {
if (!parentToChildrenMap.has(message.parentMessageId)) {
parentToChildrenMap.set(message.parentMessageId, []);
}
parentToChildrenMap.get(message.parentMessageId).push(message);
}

// Retrieve the target message
const targetMessage = messages.find((msg) => msg.messageId === targetMessageId);
if (!targetMessage) {
logger.error('Target message not found.');
return [];
}

const visited = new Set();

const rootMessages = parentToChildrenMap.get(Constants.NO_PARENT) || [];
let currentLevel = rootMessages.length > 0 ? [...rootMessages] : [targetMessage];
const results = new Set(currentLevel);

// Check if the target message is at the root level
if (
currentLevel.some((msg) => msg.messageId === targetMessageId) &&
targetMessage.parentMessageId === Constants.NO_PARENT
) {
return Array.from(results);
}

// Iterate level by level until the target is found
let targetFound = false;
while (!targetFound && currentLevel.length > 0) {
const nextLevel = [];
for (const node of currentLevel) {
if (visited.has(node.messageId)) {
logger.warn('Cycle detected in message tree');
continue;
}
visited.add(node.messageId);
const children = parentToChildrenMap.get(node.messageId) || [];
for (const child of children) {
if (visited.has(child.messageId)) {
logger.warn('Cycle detected in message tree');
continue;
}
nextLevel.push(child);
results.add(child);
if (child.messageId === targetMessageId) {
targetFound = true;
}
}
}
currentLevel = nextLevel;
}

return Array.from(results);
}

/**
* Splits the conversation at the targeted message level, including the target, its siblings, and all descendant messages.
* All target level messages have their parentMessageId set to the root.
* @param {TMessage[]} messages - The list of messages to analyze.
* @param {string} targetMessageId - The ID of the message to start the split from.
* @returns {TMessage[]} The list of messages at and below the target level.
*/
function splitAtTargetLevel(messages, targetMessageId) {
// Create a map of parentMessageId to children messages
const parentToChildrenMap = new Map();
for (const message of messages) {
if (!parentToChildrenMap.has(message.parentMessageId)) {
parentToChildrenMap.set(message.parentMessageId, []);
}
parentToChildrenMap.get(message.parentMessageId).push(message);
}

// Retrieve the target message
const targetMessage = messages.find((msg) => msg.messageId === targetMessageId);
if (!targetMessage) {
logger.error('Target message not found.');
return [];
}

// Initialize the search with root messages
const rootMessages = parentToChildrenMap.get(Constants.NO_PARENT) || [];
let currentLevel = [...rootMessages];
let currentLevelIndex = 0;
const levelMap = {};

// Map messages to their levels
rootMessages.forEach((msg) => {
levelMap[msg.messageId] = 0;
});

// Search for the target level
while (currentLevel.length > 0) {
const nextLevel = [];
for (const node of currentLevel) {
const children = parentToChildrenMap.get(node.messageId) || [];
for (const child of children) {
nextLevel.push(child);
levelMap[child.messageId] = currentLevelIndex + 1;
}
}
currentLevel = nextLevel;
currentLevelIndex++;
}

// Determine the target level
const targetLevel = levelMap[targetMessageId];
if (targetLevel === undefined) {
logger.error('Target level not found.');
return [];
}

// Filter messages at or below the target level
const filteredMessages = messages
.map((msg) => {
const messageLevel = levelMap[msg.messageId];
if (messageLevel < targetLevel) {
return null;
} else if (messageLevel === targetLevel) {
return {
...msg,
parentMessageId: Constants.NO_PARENT,
};
}

return msg;
})
.filter((msg) => msg !== null);

return filteredMessages;
}

module.exports = {
forkConversation,
splitAtTargetLevel,
getAllMessagesUpToParent,
getMessagesUpToTargetLevel,
};
Loading
Loading