From d3cafeee96d0fc3b2559f9ae1c594b9b76b6acd8 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 17 Dec 2024 22:11:18 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=8D=20feat:=20Add=20Entity=20ID=20Supp?= =?UTF-8?q?ort=20for=20File=20Search=20Shared=20Resources=20(#5028)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/clients/tools/util/fileSearch.js | 46 +++++++++++++--------- api/app/clients/tools/util/handleTools.js | 2 +- api/server/services/Files/VectorDB/crud.js | 8 +++- api/server/services/Files/process.js | 4 +- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 2d1010bd3b5..23ba58bb5a0 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -50,9 +50,10 @@ const primeFiles = async (options) => { * @param {Object} options * @param {ServerRequest} options.req * @param {Array<{ file_id: string; filename: string }>} options.files + * @param {string} [options.entity_id] * @returns */ -const createFileSearchTool = async ({ req, files }) => { +const createFileSearchTool = async ({ req, files, entity_id }) => { return tool( async ({ query }) => { if (files.length === 0) { @@ -62,27 +63,36 @@ const createFileSearchTool = async ({ req, files }) => { if (!jwtToken) { return 'There was an error authenticating the file search request.'; } + + /** + * + * @param {import('librechat-data-provider').TFile} file + * @returns {{ file_id: string, query: string, k: number, entity_id?: string }} + */ + const createQueryBody = (file) => { + const body = { + file_id: file.file_id, + query, + k: 5, + }; + if (!entity_id) { + return body; + } + body.entity_id = entity_id; + logger.debug(`[${Tools.file_search}] RAG API /query body`, body); + return body; + }; + const queryPromises = files.map((file) => axios - .post( - `${process.env.RAG_API_URL}/query`, - { - file_id: file.file_id, - query, - k: 5, - }, - { - headers: { - Authorization: `Bearer ${jwtToken}`, - 'Content-Type': 'application/json', - }, + .post(`${process.env.RAG_API_URL}/query`, createQueryBody(file), { + headers: { + Authorization: `Bearer ${jwtToken}`, + 'Content-Type': 'application/json', }, - ) + }) .catch((error) => { - logger.error( - `Error encountered in \`file_search\` while querying file_id ${file._id}:`, - error, - ); + logger.error('Error encountered in `file_search` while querying file:', error); return null; }), ); diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index a8d0ec13f82..2a26061abf4 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -256,7 +256,7 @@ const loadTools = async ({ if (toolContext) { toolContextMap[tool] = toolContext; } - return createFileSearchTool({ req: options.req, files }); + return createFileSearchTool({ req: options.req, files, entity_id: agent?.id }); }; continue; } else if (mcpToolPattern.test(tool)) { diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index a4d48064d79..d290eea4b1b 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -50,13 +50,14 @@ const deleteVectors = async (req, file) => { * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should * have a `path` property that points to the location of the uploaded file. * @param {string} params.file_id - The file ID. + * @param {string} [params.entity_id] - The entity ID for shared resources. * * @returns {Promise<{ filepath: string, bytes: number }>} * A promise that resolves to an object containing: * - filepath: The path where the file is saved. * - bytes: The size of the file in bytes. */ -async function uploadVectors({ req, file, file_id }) { +async function uploadVectors({ req, file, file_id, entity_id }) { if (!process.env.RAG_API_URL) { throw new Error('RAG_API_URL not defined'); } @@ -66,8 +67,11 @@ async function uploadVectors({ req, file, file_id }) { const formData = new FormData(); formData.append('file_id', file_id); formData.append('file', fs.createReadStream(file.path)); + if (entity_id != null && entity_id) { + formData.append('entity_id', entity_id); + } - const formHeaders = formData.getHeaders(); // Automatically sets the correct Content-Type + const formHeaders = formData.getHeaders(); const response = await axios.post(`${process.env.RAG_API_URL}/embed`, formData, { headers: { diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 198dc940007..709b2a5ce44 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -479,6 +479,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { } let fileInfoMetadata; + const entity_id = messageAttachment === true ? undefined : agent_id; if (tool_resource === EToolResources.execute_code) { const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions(FileSources.execute_code); const result = await loadAuthValues({ userId: req.user.id, authFields: [EnvVar.CODE_API_KEY] }); @@ -488,7 +489,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { stream, filename: file.originalname, apiKey: result[EnvVar.CODE_API_KEY], - entity_id: messageAttachment === true ? undefined : agent_id, + entity_id, }); fileInfoMetadata = { fileIdentifier }; } @@ -512,6 +513,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { req, file, file_id, + entity_id, }); let filepath = _filepath;