From 7346d20224d71d442e2828e99084eee1990a3863 Mon Sep 17 00:00:00 2001 From: Ruben Talstra Date: Sun, 16 Feb 2025 11:56:40 +0100 Subject: [PATCH] refactor: request is encrypted. response from AI is still saved in plaintext but from the stream the final response is encrypted. --- api/server/controllers/AskController.js | 143 ++++++++++++++---------- 1 file changed, 87 insertions(+), 56 deletions(-) diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 04939d3fcc3..83177555d6a 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -1,16 +1,59 @@ const { getResponseSender, Constants } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage, createOnProgress } = require('~/server/utils'); -const { saveMessage } = require('~/models'); +const { saveMessage, getUserById } = require('~/models'); const { logger } = require('~/config'); let crypto; try { crypto = require('crypto'); } catch (err) { - logger.error('[openidStrategy] crypto support is disabled!', err); + logger.error('[AskController] crypto support is disabled!', err); } +/** + * Helper function to encrypt plaintext using AES-256-GCM and then RSA-encrypt the AES key. + * @param {string} plainText - The plaintext to encrypt. + * @param {string} pemPublicKey - The RSA public key in PEM format. + * @returns {Object} An object containing the ciphertext, iv, authTag, and encryptedKey. + */ +function encryptText(plainText, pemPublicKey) { + // Generate a random 256-bit AES key and a 12-byte IV. + const aesKey = crypto.randomBytes(32); + const iv = crypto.randomBytes(12); + + // Encrypt the plaintext using AES-256-GCM. + const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv); + let ciphertext = cipher.update(plainText, 'utf8', 'base64'); + ciphertext += cipher.final('base64'); + const authTag = cipher.getAuthTag().toString('base64'); + + // Encrypt the AES key using the user's RSA public key. + const encryptedKey = crypto.publicEncrypt( + { + key: pemPublicKey, + padding: crypto.constants.RSA_PKCS1_OAEP_PADDING, + oaepHash: 'sha256', + }, + aesKey, + ).toString('base64'); + + return { + ciphertext, + iv: iv.toString('base64'), + authTag, + encryptedKey, + }; +} + +/** + * AskController + * - Initializes the client. + * - Obtains the response from the language model. + * - Retrieves the full user record (to get encryption parameters). + * - If the user has encryption enabled (i.e. encryptionPublicKey is provided), + * encrypts both the request (userMessage) and the response before saving. + */ const AskController = async (req, res, next, initializeClient, addTitle) => { let { text, @@ -39,7 +82,17 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { modelDisplayLabel, }); const newConvo = !conversationId; - const user = req.user.id; + const userId = req.user.id; // User ID from authentication + + // Retrieve full user record from DB (including encryption parameters) + const dbUser = await getUserById(userId, 'encryptionPublicKey encryptedPrivateKey encryptionSalt encryptionIV'); + + // If the user has provided an encryption public key, rebuild the PEM format. + let pemPublicKey = null; + if (dbUser?.encryptionPublicKey && dbUser.encryptionPublicKey.trim() !== '') { + const pubKeyBase64 = dbUser.encryptionPublicKey; + pemPublicKey = `-----BEGIN PUBLIC KEY-----\n${pubKeyBase64.match(/.{1,64}/g).join('\n')}\n-----END PUBLIC KEY-----`; + } const getReqData = (data = {}) => { for (let key in data) { @@ -59,11 +112,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { }; let getText; - try { const { client } = await initializeClient({ req, res, endpointOption }); const { onProgress: progressCallback, getPartialText } = createOnProgress(); - getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText; const getAbortData = () => ({ @@ -81,14 +132,14 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { res.on('close', () => { logger.debug('[AskController] Request closed'); - if (!abortController) {return;} - if (abortController.signal.aborted || abortController.requestCompleted) {return;} + if (!abortController) { return; } + if (abortController.signal.aborted || abortController.requestCompleted) { return; } abortController.abort(); logger.debug('[AskController] Request aborted on close'); }); const messageOptions = { - user, + user: userId, parentMessageId, conversationId, overrideParentMessageId, @@ -99,10 +150,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { progressOptions: { res }, }; - /** @type {TMessage} */ + // Get the response from the language model client. let response = await client.sendMessage(text, messageOptions); response.endpoint = endpointOption.endpoint; + // Ensure the conversation has a title. const { conversation = {} } = await client.responsePromise; conversation.title = conversation && !conversation.title ? null : conversation?.title || 'New Chat'; @@ -113,54 +165,33 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { delete userMessage.image_urls; } - // --- Encryption Branch --- - // Only encrypt if the user has set up encryption (i.e. non-empty encryptionPublicKey) - if ( - req.user.encryptionPublicKey && - req.user.encryptionPublicKey.trim() !== '' && - response.text && - crypto - ) { + // --- Encrypt the user message if encryption is enabled --- + if (pemPublicKey && userMessage && userMessage.text) { + try { + const { ciphertext, iv, authTag, encryptedKey } = encryptText(userMessage.text, pemPublicKey); + userMessage.text = ciphertext; + userMessage.iv = iv; + userMessage.authTag = authTag; + userMessage.encryptedKey = encryptedKey; + logger.debug('[AskController] User message encrypted.'); + } catch (encError) { + logger.error('[AskController] Error encrypting user message:', encError); + // Optionally, you could choose to throw an error or fallback. + } + } + + // --- Encrypt the AI response if encryption is enabled --- + if (pemPublicKey && response.text) { try { - // Reconstruct the user's RSA public key in PEM format. - const pubKeyBase64 = req.user.encryptionPublicKey; - const pemPublicKey = `-----BEGIN PUBLIC KEY-----\n${pubKeyBase64.match(/.{1,64}/g).join('\n')}\n-----END PUBLIC KEY-----`; - - // Generate a random 256-bit AES key and a 12-byte IV. - const aesKey = crypto.randomBytes(32); - const iv = crypto.randomBytes(12); - - // Encrypt the response text using AES-GCM. - const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv); - let ciphertext = cipher.update(response.text, 'utf8', 'base64'); - ciphertext += cipher.final('base64'); - const authTag = cipher.getAuthTag().toString('base64'); - - // Encrypt the AES key using the client's RSA public key. - let encryptedKey; - try { - encryptedKey = crypto.publicEncrypt( - { - key: pemPublicKey, - padding: crypto.constants.RSA_PKCS1_OAEP_PADDING, - oaepHash: 'sha256', - }, - aesKey, - ).toString('base64'); - } catch (err) { - logger.error('Error encrypting AES key:', err); - throw new Error('Encryption failure'); - } - - // Replace the plaintext response with the encrypted payload. + const { ciphertext, iv, authTag, encryptedKey } = encryptText(response.text, pemPublicKey); response.text = ciphertext; - response.iv = iv.toString('base64'); + response.iv = iv; response.authTag = authTag; response.encryptedKey = encryptedKey; logger.debug('[AskController] Response message encrypted.'); } catch (encError) { - logger.error('[AskController] Error during response encryption:', encError); - // Optionally, you may choose to return plaintext if encryption fails. + logger.error('[AskController] Error encrypting response message:', encError); + // Optionally, you can choose to send plaintext or handle the error. } } // --- End Encryption Branch --- @@ -178,15 +209,15 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { if (!client.savedMessageIds.has(response.messageId)) { await saveMessage( req, - { ...response, user }, - { context: 'api/server/controllers/AskController.js - response end' }, + { ...response, user: userId }, + { context: 'AskController - response end' }, ); } } if (!client.skipSaveUserMessage) { await saveMessage(req, userMessage, { - context: 'api/server/controllers/AskController.js - don\'t skip saving user message', + context: 'AskController - save user message', }); } @@ -206,9 +237,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { messageId: responseMessageId, parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId, }).catch((err) => { - logger.error('[AskController] Error in `handleAbortError`', err); + logger.error('[AskController] Error in handleAbortError', err); }); } }; -module.exports = AskController; +module.exports = AskController; \ No newline at end of file