Skip to content

Commit

Permalink
refactor: request is encrypted. response from AI is still saved in pl…
Browse files Browse the repository at this point in the history
…aintext but from the stream the final response is encrypted.
  • Loading branch information
rubentalstra committed Feb 16, 2025
1 parent 0cc0e5d commit 7346d20
Showing 1 changed file with 87 additions and 56 deletions.
143 changes: 87 additions & 56 deletions api/server/controllers/AskController.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,59 @@
const { getResponseSender, Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { saveMessage, getUserById } = require('~/models');
const { logger } = require('~/config');

let crypto;
try {
crypto = require('crypto');
} catch (err) {
logger.error('[openidStrategy] crypto support is disabled!', err);
logger.error('[AskController] crypto support is disabled!', err);
}

/**
* Helper function to encrypt plaintext using AES-256-GCM and then RSA-encrypt the AES key.
* @param {string} plainText - The plaintext to encrypt.
* @param {string} pemPublicKey - The RSA public key in PEM format.
* @returns {Object} An object containing the ciphertext, iv, authTag, and encryptedKey.
*/
function encryptText(plainText, pemPublicKey) {
// Generate a random 256-bit AES key and a 12-byte IV.
const aesKey = crypto.randomBytes(32);
const iv = crypto.randomBytes(12);

// Encrypt the plaintext using AES-256-GCM.
const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv);
let ciphertext = cipher.update(plainText, 'utf8', 'base64');
ciphertext += cipher.final('base64');
const authTag = cipher.getAuthTag().toString('base64');

// Encrypt the AES key using the user's RSA public key.
const encryptedKey = crypto.publicEncrypt(
{
key: pemPublicKey,
padding: crypto.constants.RSA_PKCS1_OAEP_PADDING,
oaepHash: 'sha256',
},
aesKey,
).toString('base64');

return {
ciphertext,
iv: iv.toString('base64'),
authTag,
encryptedKey,
};
}

/**
* AskController
* - Initializes the client.
* - Obtains the response from the language model.
* - Retrieves the full user record (to get encryption parameters).
* - If the user has encryption enabled (i.e. encryptionPublicKey is provided),
* encrypts both the request (userMessage) and the response before saving.
*/
const AskController = async (req, res, next, initializeClient, addTitle) => {
let {
text,
Expand Down Expand Up @@ -39,7 +82,17 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
modelDisplayLabel,
});
const newConvo = !conversationId;
const user = req.user.id;
const userId = req.user.id; // User ID from authentication

// Retrieve full user record from DB (including encryption parameters)
const dbUser = await getUserById(userId, 'encryptionPublicKey encryptedPrivateKey encryptionSalt encryptionIV');

// If the user has provided an encryption public key, rebuild the PEM format.
let pemPublicKey = null;
if (dbUser?.encryptionPublicKey && dbUser.encryptionPublicKey.trim() !== '') {
const pubKeyBase64 = dbUser.encryptionPublicKey;
pemPublicKey = `-----BEGIN PUBLIC KEY-----\n${pubKeyBase64.match(/.{1,64}/g).join('\n')}\n-----END PUBLIC KEY-----`;
}

const getReqData = (data = {}) => {
for (let key in data) {
Expand All @@ -59,11 +112,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
};

let getText;

try {
const { client } = await initializeClient({ req, res, endpointOption });
const { onProgress: progressCallback, getPartialText } = createOnProgress();

getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText;

const getAbortData = () => ({
Expand All @@ -81,14 +132,14 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {

res.on('close', () => {
logger.debug('[AskController] Request closed');
if (!abortController) {return;}
if (abortController.signal.aborted || abortController.requestCompleted) {return;}
if (!abortController) { return; }
if (abortController.signal.aborted || abortController.requestCompleted) { return; }
abortController.abort();
logger.debug('[AskController] Request aborted on close');
});

const messageOptions = {
user,
user: userId,
parentMessageId,
conversationId,
overrideParentMessageId,
Expand All @@ -99,10 +150,11 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
progressOptions: { res },
};

/** @type {TMessage} */
// Get the response from the language model client.
let response = await client.sendMessage(text, messageOptions);
response.endpoint = endpointOption.endpoint;

// Ensure the conversation has a title.
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
Expand All @@ -113,54 +165,33 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
delete userMessage.image_urls;
}

// --- Encryption Branch ---
// Only encrypt if the user has set up encryption (i.e. non-empty encryptionPublicKey)
if (
req.user.encryptionPublicKey &&
req.user.encryptionPublicKey.trim() !== '' &&
response.text &&
crypto
) {
// --- Encrypt the user message if encryption is enabled ---
if (pemPublicKey && userMessage && userMessage.text) {
try {
const { ciphertext, iv, authTag, encryptedKey } = encryptText(userMessage.text, pemPublicKey);
userMessage.text = ciphertext;
userMessage.iv = iv;
userMessage.authTag = authTag;
userMessage.encryptedKey = encryptedKey;
logger.debug('[AskController] User message encrypted.');
} catch (encError) {
logger.error('[AskController] Error encrypting user message:', encError);
// Optionally, you could choose to throw an error or fallback.
}
}

// --- Encrypt the AI response if encryption is enabled ---
if (pemPublicKey && response.text) {
try {
// Reconstruct the user's RSA public key in PEM format.
const pubKeyBase64 = req.user.encryptionPublicKey;
const pemPublicKey = `-----BEGIN PUBLIC KEY-----\n${pubKeyBase64.match(/.{1,64}/g).join('\n')}\n-----END PUBLIC KEY-----`;

// Generate a random 256-bit AES key and a 12-byte IV.
const aesKey = crypto.randomBytes(32);
const iv = crypto.randomBytes(12);

// Encrypt the response text using AES-GCM.
const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv);
let ciphertext = cipher.update(response.text, 'utf8', 'base64');
ciphertext += cipher.final('base64');
const authTag = cipher.getAuthTag().toString('base64');

// Encrypt the AES key using the client's RSA public key.
let encryptedKey;
try {
encryptedKey = crypto.publicEncrypt(
{
key: pemPublicKey,
padding: crypto.constants.RSA_PKCS1_OAEP_PADDING,
oaepHash: 'sha256',
},
aesKey,
).toString('base64');
} catch (err) {
logger.error('Error encrypting AES key:', err);
throw new Error('Encryption failure');
}

// Replace the plaintext response with the encrypted payload.
const { ciphertext, iv, authTag, encryptedKey } = encryptText(response.text, pemPublicKey);
response.text = ciphertext;
response.iv = iv.toString('base64');
response.iv = iv;
response.authTag = authTag;
response.encryptedKey = encryptedKey;
logger.debug('[AskController] Response message encrypted.');
} catch (encError) {
logger.error('[AskController] Error during response encryption:', encError);
// Optionally, you may choose to return plaintext if encryption fails.
logger.error('[AskController] Error encrypting response message:', encError);
// Optionally, you can choose to send plaintext or handle the error.
}
}
// --- End Encryption Branch ---
Expand All @@ -178,15 +209,15 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
if (!client.savedMessageIds.has(response.messageId)) {
await saveMessage(
req,
{ ...response, user },
{ context: 'api/server/controllers/AskController.js - response end' },
{ ...response, user: userId },
{ context: 'AskController - response end' },
);
}
}

if (!client.skipSaveUserMessage) {
await saveMessage(req, userMessage, {
context: 'api/server/controllers/AskController.js - don\'t skip saving user message',
context: 'AskController - save user message',
});
}

Expand All @@ -206,9 +237,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
}).catch((err) => {
logger.error('[AskController] Error in `handleAbortError`', err);
logger.error('[AskController] Error in handleAbortError', err);
});
}
};

module.exports = AskController;
module.exports = AskController;

0 comments on commit 7346d20

Please sign in to comment.