diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 6cda4471421..6ddcaa97219 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -7,15 +7,12 @@ const { EModelEndpoint, ErrorTypes, Constants, - CacheKeys, - Time, } = require('librechat-data-provider'); const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const { truncateToolCallOutputs } = require('./prompts'); const checkBalance = require('~/models/checkBalance'); const { getFiles } = require('~/models/File'); -const { getLogStores } = require('~/cache'); const TextStream = require('./TextStream'); const { logger } = require('~/config'); @@ -54,6 +51,12 @@ class BaseClient { this.outputTokensKey = 'completion_tokens'; /** @type {Set} */ this.savedMessageIds = new Set(); + /** + * Flag to determine if the client re-submitted the latest assistant message. + * @type {boolean | undefined} */ + this.continued; + /** @type {TMessage[]} */ + this.currentMessages = []; } setOptions() { @@ -589,6 +592,7 @@ class BaseClient { } else { latestMessage.text = generation; } + this.continued = true; } else { this.currentMessages.push(userMessage); } @@ -720,17 +724,6 @@ class BaseClient { this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); this.savedMessageIds.add(responseMessage.messageId); - if (responseMessage.text) { - const messageCache = getLogStores(CacheKeys.MESSAGES); - messageCache.set( - responseMessageId, - { - text: responseMessage.text, - complete: true, - }, - Time.FIVE_MINUTES, - ); - } delete responseMessage.tokenCount; return responseMessage; } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 89b938b8582..8b71dcbc52c 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,6 +1,7 @@ const OpenAI = require('openai'); const { OllamaClient } = 
require('./OllamaClient'); const { HttpsProxyAgent } = require('https-proxy-agent'); +const { SplitStreamHandler, GraphEvents } = require('@librechat/agents'); const { Constants, ImageDetail, @@ -28,17 +29,17 @@ const { createContextHandlers, } = require('./prompts'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils'); const Tokenizer = require('~/server/services/Tokenizer'); const { spendTokens } = require('~/models/spendTokens'); -const { isEnabled, sleep } = require('~/server/utils'); const { handleOpenAIErrors } = require('./tools/util'); const { createLLM, RunManager } = require('./llm'); +const { logger, sendEvent } = require('~/config'); const ChatGPTClient = require('./ChatGPTClient'); const { summaryBuffer } = require('./memory'); const { runTitleChain } = require('./chains'); const { tokenSplit } = require('./document'); const BaseClient = require('./BaseClient'); -const { logger } = require('~/config'); class OpenAIClient extends BaseClient { constructor(apiKey, options = {}) { @@ -65,6 +66,8 @@ class OpenAIClient extends BaseClient { this.usage; /** @type {boolean|undefined} */ this.isO1Model; + /** @type {SplitStreamHandler | undefined} */ + this.streamHandler; } // TODO: PluginsClient calls this 3x, unneeded @@ -1064,11 +1067,36 @@ ${convo} }); } + getStreamText() { + if (!this.streamHandler) { + return ''; + } + + const reasoningTokens = + this.streamHandler.reasoningTokens.length > 0 + ? 
`:::thinking\n${this.streamHandler.reasoningTokens.join('')}\n:::\n` + : ''; + + return `${reasoningTokens}${this.streamHandler.tokens.join('')}`; + } + + getMessageMapMethod() { + /** + * @param {TMessage} msg + */ + return (msg) => { + if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) { + msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim(); + } + + return msg; + }; + } + async chatCompletion({ payload, onProgress, abortController = null }) { let error = null; + let intermediateReply = []; const errorCallback = (err) => (error = err); - const intermediateReply = []; - const reasoningTokens = []; try { if (!abortController) { abortController = new AbortController(); @@ -1266,6 +1294,19 @@ ${convo} reasoningKey = 'reasoning'; } + this.streamHandler = new SplitStreamHandler({ + reasoningKey, + accumulate: true, + runId: this.responseMessageId, + handlers: { + [GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event), + [GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event), + [GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event), + }, + }); + + intermediateReply = this.streamHandler.tokens; + if (modelOptions.stream) { streamPromise = new Promise((resolve) => { streamResolve = resolve; @@ -1292,41 +1333,36 @@ ${convo} } if (typeof finalMessage.content !== 'string' || finalMessage.content.trim() === '') { - finalChatCompletion.choices[0].message.content = intermediateReply.join(''); + finalChatCompletion.choices[0].message.content = this.streamHandler.tokens.join(''); } }) .on('finalMessage', (message) => { if (message?.role !== 'assistant') { - stream.messages.push({ role: 'assistant', content: intermediateReply.join('') }); + stream.messages.push({ + role: 'assistant', + content: this.streamHandler.tokens.join(''), + }); UnexpectedRoleError = true; } }); - let reasoningCompleted = false; - for await (const chunk of stream) { - if 
(chunk?.choices?.[0]?.delta?.[reasoningKey]) { - if (reasoningTokens.length === 0) { - const thinkingDirective = '<think>\n'; - intermediateReply.push(thinkingDirective); - reasoningTokens.push(thinkingDirective); - onProgress(thinkingDirective); - } - const reasoning_content = chunk?.choices?.[0]?.delta?.[reasoningKey] || ''; - intermediateReply.push(reasoning_content); - reasoningTokens.push(reasoning_content); - onProgress(reasoning_content); - } - - const token = chunk?.choices?.[0]?.delta?.content || ''; - if (!reasoningCompleted && reasoningTokens.length > 0 && token) { - reasoningCompleted = true; - const separatorTokens = '\n</think>\n'; - reasoningTokens.push(separatorTokens); - onProgress(separatorTokens); - } + if (this.continued === true) { + const latestText = addSpaceIfNeeded( + this.currentMessages[this.currentMessages.length - 1]?.text ?? '', + ); + this.streamHandler.handle({ + choices: [ + { + delta: { + content: latestText, + }, + }, + ], + }); + } - intermediateReply.push(token); - onProgress(token); + for await (const chunk of stream) { + this.streamHandler.handle(chunk); if (abortController.signal.aborted) { stream.controller.abort(); break; @@ -1369,7 +1405,7 @@ ${convo} if (!Array.isArray(choices) || choices.length === 0) { logger.warn('[OpenAIClient] Chat completion response has no choices'); - return intermediateReply.join(''); + return this.streamHandler.tokens.join(''); } const { message, finish_reason } = choices[0] ?? 
{}; @@ -1379,11 +1415,11 @@ ${convo} if (!message) { logger.warn('[OpenAIClient] Message is undefined in chatCompletion response'); - return intermediateReply.join(''); + return this.streamHandler.tokens.join(''); } if (typeof message.content !== 'string' || message.content.trim() === '') { - const reply = intermediateReply.join(''); + const reply = this.streamHandler.tokens.join(''); logger.debug( '[OpenAIClient] chatCompletion: using intermediateReply due to empty message.content', { intermediateReply: reply }, @@ -1391,8 +1427,18 @@ ${convo} return reply; } - if (reasoningTokens.length > 0 && this.options.context !== 'title') { - return reasoningTokens.join('') + message.content; + if ( + this.streamHandler.reasoningTokens.length > 0 && + this.options.context !== 'title' && + !message.content.startsWith('<think>') + ) { + return this.getStreamText(); + } else if ( + this.streamHandler.reasoningTokens.length > 0 && + this.options.context !== 'title' && + message.content.startsWith('<think>') + ) { + return message.content.replace('<think>', ':::thinking').replace('</think>', ':::'); } return message.content; diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 0e518bfea02..c15258f98c1 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -1,5 +1,4 @@ const OpenAIClient = require('./OpenAIClient'); -const { CacheKeys, Time } = require('librechat-data-provider'); const { CallbackManager } = require('@langchain/core/callbacks/manager'); const { BufferMemory, ChatMessageHistory } = require('langchain/memory'); const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers'); @@ -11,7 +10,6 @@ const checkBalance = require('~/models/checkBalance'); const { isEnabled } = require('~/server/utils'); const { extractBaseURL } = require('~/utils'); const { loadTools } = require('./tools/util'); -const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); class PluginsClient extends 
OpenAIClient { @@ -256,17 +254,6 @@ class PluginsClient extends OpenAIClient { } this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); - if (responseMessage.text) { - const messageCache = getLogStores(CacheKeys.MESSAGES); - messageCache.set( - responseMessage.messageId, - { - text: responseMessage.text, - complete: true, - }, - Time.FIVE_MINUTES, - ); - } delete responseMessage.tokenCount; return { ...responseMessage, ...result }; } } diff --git a/api/config/index.js b/api/config/index.js index c66d92ae434..c2b21cfc079 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -16,7 +16,22 @@ async function getMCPManager() { return mcpManager; } +/** + * Sends message data in Server Sent Events format. + * @param {ServerResponse} res - The server response. + * @param {{ data: string | Record<string, unknown>, event?: string }} event - The message event. + * @param {string} event.event - The type of event. + * @param {string} event.data - The message to be sent. + */ +const sendEvent = (res, event) => { + if (typeof event.data === 'string' && event.data.length === 0) { + return; + } + res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); +}; + module.exports = { logger, + sendEvent, getMCPManager, }; diff --git a/api/package.json b/api/package.json index fe8b1f1f280..f7104d89445 100644 --- a/api/package.json +++ b/api/package.json @@ -41,10 +41,10 @@ "@keyv/redis": "^2.8.1", "@langchain/community": "^0.3.14", "@langchain/core": "^0.3.18", - "@langchain/google-genai": "^0.1.6", - "@langchain/google-vertexai": "^0.1.6", + "@langchain/google-genai": "^0.1.7", + "@langchain/google-vertexai": "^0.1.8", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^1.9.94", + "@librechat/agents": "^1.9.97", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "^1.7.7", "bcryptjs": "^2.4.3", diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 6534d6b3b32..b952ab00426 100644 --- 
a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -1,8 +1,6 @@ -const throttle = require('lodash/throttle'); -const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage, createOnProgress } = require('~/server/utils'); -const { getLogStores } = require('~/cache'); const { saveMessage } = require('~/models'); const { logger } = require('~/config'); @@ -57,33 +55,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { try { const { client } = await initializeClient({ req, res, endpointOption }); - const messageCache = getLogStores(CacheKeys.MESSAGES); - const { onProgress: progressCallback, getPartialText } = createOnProgress({ - onProgress: throttle( - ({ text: partialText }) => { - /* - const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true; - messageCache.set(responseMessageId, { - messageId: responseMessageId, - sender, - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - model: client.modelOptions.model, - unfinished, - error: false, - user, - }, Time.FIVE_MINUTES); - */ - - messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES); - }, - 3000, - { trailing: false }, - ), - }); + const { onProgress: progressCallback, getPartialText } = createOnProgress(); - getText = getPartialText; + getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText; const getAbortData = () => ({ sender, @@ -91,7 +65,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { userMessagePromise, messageId: responseMessageId, parentMessageId: overrideParentMessageId ?? 
userMessageId, - text: getPartialText(), + text: getText(), userMessage, promptTokens, }); diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index 28fe2c4fea1..ec618eabcf8 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -1,8 +1,6 @@ -const throttle = require('lodash/throttle'); -const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider'); +const { getResponseSender } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage, createOnProgress } = require('~/server/utils'); -const { getLogStores } = require('~/cache'); const { saveMessage } = require('~/models'); const { logger } = require('~/config'); @@ -53,61 +51,43 @@ const EditController = async (req, res, next, initializeClient) => { } }; - const messageCache = getLogStores(CacheKeys.MESSAGES); const { onProgress: progressCallback, getPartialText } = createOnProgress({ generation, - onProgress: throttle( - ({ text: partialText }) => { - /* - const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true; - { - messageId: responseMessageId, - sender, - conversationId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished, - isEdited: true, - error: false, - user, - } */ - messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES); - }, - 3000, - { trailing: false }, - ), }); - const getAbortData = () => ({ - conversationId, - userMessagePromise, - messageId: responseMessageId, - sender, - parentMessageId: overrideParentMessageId ?? 
userMessageId, - text: getPartialText(), - userMessage, - promptTokens, - }); + let getText; - const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); + try { + const { client } = await initializeClient({ req, res, endpointOption }); - res.on('close', () => { - logger.debug('[EditController] Request closed'); - if (!abortController) { - return; - } else if (abortController.signal.aborted) { - return; - } else if (abortController.requestCompleted) { - return; - } + getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText; - abortController.abort(); - logger.debug('[EditController] Request aborted on close'); - }); + const getAbortData = () => ({ + conversationId, + userMessagePromise, + messageId: responseMessageId, + sender, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getText(), + userMessage, + promptTokens, + }); - try { - const { client } = await initializeClient({ req, res, endpointOption }); + const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); + + res.on('close', () => { + logger.debug('[EditController] Request closed'); + if (!abortController) { + return; + } else if (abortController.signal.aborted) { + return; + } else if (abortController.requestCompleted) { + return; + } + + abortController.abort(); + logger.debug('[EditController] Request aborted on close'); + }); let response = await client.sendMessage(text, { user, @@ -153,7 +133,7 @@ const EditController = async (req, res, next, initializeClient) => { ); } } catch (error) { - const partialText = getPartialText(); + const partialText = getText(); handleAbortError(res, req, error, { partialText, conversationId, diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 706b9db83d7..53b45d3b6d6 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -10,7 +10,7 @@ 
const { const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { saveBase64Image } = require('~/server/services/Files/process'); const { loadAuthValues } = require('~/app/clients/tools/util'); -const { logger } = require('~/config'); +const { logger, sendEvent } = require('~/config'); /** @typedef {import('@librechat/agents').Graph} Graph */ /** @typedef {import('@librechat/agents').EventHandler} EventHandler */ @@ -21,20 +21,6 @@ const { logger } = require('~/config'); /** @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator */ /** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */ -/** - * Sends message data in Server Sent Events format. - * @param {ServerResponse} res - The server response. - * @param {{ data: string | Record, event?: string }} event - The message event. - * @param {string} event.event - The type of event. - * @param {string} event.data - The message to be sent. - */ -const sendEvent = (res, event) => { - if (typeof event.data === 'string' && event.data.length === 0) { - return; - } - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); -}; - class ModelEndHandler { /** * @param {Array} collectedUsage @@ -322,7 +308,6 @@ function createToolEndCallback({ req, res, artifactPromises }) { } module.exports = { - sendEvent, getDefaultHandlers, createToolEndCallback, }; diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 047a413433a..24a8e38fa4d 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -397,18 +397,6 @@ const chatV2 = async (req, res) => { response = streamRunManager; response.text = streamRunManager.intermediateText; - - if (response.text) { - const messageCache = getLogStores(CacheKeys.MESSAGES); - messageCache.set( - responseMessageId, - { - complete: true, - text: response.text, - }, - Time.FIVE_MINUTES, - ); - } }; await 
processRun(); diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 8408e7a5d6d..036654f845b 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -1,11 +1,9 @@ const express = require('express'); -const throttle = require('lodash/throttle'); -const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider'); +const { getResponseSender, Constants } = require('librechat-data-provider'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); const { sendMessage, createOnProgress } = require('~/server/utils'); const { addTitle } = require('~/server/services/Endpoints/openAI'); const { saveMessage, updateMessage } = require('~/models'); -const { getLogStores } = require('~/cache'); const { handleAbort, createAbortController, @@ -72,15 +70,6 @@ router.post( } }; - const messageCache = getLogStores(CacheKeys.MESSAGES); - const throttledCacheSet = throttle( - (text) => { - messageCache.set(responseMessageId, text, Time.FIVE_MINUTES); - }, - 3000, - { trailing: false }, - ); - let streaming = null; let timer = null; @@ -89,13 +78,11 @@ router.post( sendIntermediateMessage, getPartialText, } = createOnProgress({ - onProgress: ({ text: partialText }) => { + onProgress: () => { if (timer) { clearTimeout(timer); } - throttledCacheSet(partialText); - streaming = new Promise((resolve) => { timer = setTimeout(() => { resolve(); diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index cd4e575e836..5547a1fcdf9 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -1,6 +1,5 @@ const express = require('express'); -const throttle = require('lodash/throttle'); -const { getResponseSender, CacheKeys, Time } = require('librechat-data-provider'); +const { getResponseSender } = require('librechat-data-provider'); const { setHeaders, handleAbort, @@ -14,7 +13,6 @@ const { const { 
sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); const { saveMessage, updateMessage } = require('~/models'); -const { getLogStores } = require('~/cache'); const { validateTools } = require('~/app'); const { logger } = require('~/config'); @@ -80,26 +78,16 @@ router.post( } }; - const messageCache = getLogStores(CacheKeys.MESSAGES); - const throttledCacheSet = throttle( - (text) => { - messageCache.set(responseMessageId, text, Time.FIVE_MINUTES); - }, - 3000, - { trailing: false }, - ); - const { onProgress: progressCallback, sendIntermediateMessage, getPartialText, } = createOnProgress({ generation, - onProgress: ({ text: partialText }) => { + onProgress: () => { if (plugin.loading === true) { plugin.loading = false; } - throttledCacheSet(partialText); }, }); diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 770cb0f67e5..54c4aab1c2d 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -21,7 +21,7 @@ router.post('/artifact/:messageId', async (req, res) => { const { messageId } = req.params; const { index, original, updated } = req.body; - if (typeof index !== 'number' || index < 0 || !original || !updated) { + if (typeof index !== 'number' || index < 0 || original == null || updated == null) { return res.status(400).json({ error: 'Invalid request parameters' }); } diff --git a/api/server/services/Artifacts/update.js b/api/server/services/Artifacts/update.js index f17c2c0b678..69cb4bb5c43 100644 --- a/api/server/services/Artifacts/update.js +++ b/api/server/services/Artifacts/update.js @@ -57,14 +57,42 @@ const findAllArtifacts = (message) => { const replaceArtifactContent = (originalText, artifact, original, updated) => { const artifactContent = artifact.text.substring(artifact.start, artifact.end); - const relativeIndex = artifactContent.indexOf(original); + + // Find boundaries between 
ARTIFACT_START and ARTIFACT_END + const contentStart = artifactContent.indexOf('\n', artifactContent.indexOf(ARTIFACT_START)) + 1; + const contentEnd = artifactContent.lastIndexOf(ARTIFACT_END); + + if (contentStart === -1 || contentEnd === -1) { + return null; + } + + // Check if there are code blocks + const codeBlockStart = artifactContent.indexOf('```\n', contentStart); + const codeBlockEnd = artifactContent.lastIndexOf('\n```', contentEnd); + + // Determine where to look for the original content + let searchStart, searchEnd; + if (codeBlockStart !== -1 && codeBlockEnd !== -1) { + // If code blocks exist, search between them + searchStart = codeBlockStart + 4; // after ```\n + searchEnd = codeBlockEnd; + } else { + // Otherwise search in the whole artifact content + searchStart = contentStart; + searchEnd = contentEnd; + } + + const innerContent = artifactContent.substring(searchStart, searchEnd); + // Remove trailing newline from original for comparison + const originalTrimmed = original.replace(/\n$/, ''); + const relativeIndex = innerContent.indexOf(originalTrimmed); if (relativeIndex === -1) { return null; } - const absoluteIndex = artifact.start + relativeIndex; - const endText = originalText.substring(absoluteIndex + original.length); + const absoluteIndex = artifact.start + searchStart + relativeIndex; + const endText = originalText.substring(absoluteIndex + originalTrimmed.length); const hasTrailingNewline = endText.startsWith('\n'); const updatedText = diff --git a/api/server/services/Artifacts/update.spec.js b/api/server/services/Artifacts/update.spec.js index 8008e553baf..2f5b9d7bf64 100644 --- a/api/server/services/Artifacts/update.spec.js +++ b/api/server/services/Artifacts/update.spec.js @@ -260,8 +260,61 @@ console.log(greeting);`; codeExample, 'updated content', ); - console.log(result); expect(result).toMatch(/id="2".*updated content/s); expect(result).toMatch(new RegExp(`${ARTIFACT_START}.*updated content.*${ARTIFACT_END}`, 's')); }); + + 
test('should handle empty content in artifact without code blocks', () => { + const artifactText = `${ARTIFACT_START}\n\n${ARTIFACT_END}`; + const artifact = { + start: 0, + end: artifactText.length, + text: artifactText, + source: 'text', + }; + + const result = replaceArtifactContent(artifactText, artifact, '', 'new content'); + expect(result).toBe(`${ARTIFACT_START}\nnew content\n${ARTIFACT_END}`); + }); + + test('should handle empty content in artifact with code blocks', () => { + const artifactText = createArtifactText({ content: '' }); + const artifact = { + start: 0, + end: artifactText.length, + text: artifactText, + source: 'text', + }; + + const result = replaceArtifactContent(artifactText, artifact, '', 'new content'); + expect(result).toMatch(/```\nnew content\n```/); + }); + + test('should handle content with trailing newline in code blocks', () => { + const contentWithNewline = 'console.log("test")\n'; + const message = { + text: `Some prefix text\n${createArtifactText({ + content: contentWithNewline, + })}\nSome suffix text`, + }; + + const artifacts = findAllArtifacts(message); + expect(artifacts).toHaveLength(1); + + const result = replaceArtifactContent( + message.text, + artifacts[0], + contentWithNewline, + 'updated content', + ); + + // Should update the content and preserve artifact structure + expect(result).toContain('```\nupdated content\n```'); + // Should preserve surrounding text + expect(result).toMatch(/^Some prefix text\n/); + expect(result).toMatch(/\nSome suffix text$/); + // Should not have extra newlines + expect(result).not.toContain('\n\n```'); + expect(result).not.toContain('```\n\n'); + }); }); diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index bfb90843da4..cd718fdfc15 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -364,7 +364,7 @@ class TTSService { shouldContinue = false; }); - const processChunks 
= createChunkProcessor(req.body.messageId); + const processChunks = createChunkProcessor(req.user.id, req.body.messageId); try { while (shouldContinue) { diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index 7b6bef03f84..ac046e68a65 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,4 +1,5 @@ -const { CacheKeys, findLastSeparatorIndex, SEPARATORS } = require('librechat-data-provider'); +const { CacheKeys, findLastSeparatorIndex, SEPARATORS, Time } = require('librechat-data-provider'); +const { getMessage } = require('~/models/Message'); const { getLogStores } = require('~/cache'); /** @@ -47,10 +48,11 @@ const MAX_NOT_FOUND_COUNT = 6; const MAX_NO_CHANGE_COUNT = 10; /** + * @param {string} user * @param {string} messageId * @returns {() => Promise<{ text: string, isFinished: boolean }[]>} */ -function createChunkProcessor(messageId) { +function createChunkProcessor(user, messageId) { let notFoundCount = 0; let noChangeCount = 0; let processedText = ''; @@ -73,15 +75,27 @@ function createChunkProcessor(messageId) { } /** @type { string | { text: string; complete: boolean } } */ - const message = await messageCache.get(messageId); + let message = await messageCache.get(messageId); + if (!message) { + message = await getMessage({ user, messageId }); + } if (!message) { notFoundCount++; return []; + } else { + messageCache.set( + messageId, + { + text: message.text, + complete: true, + }, + Time.FIVE_MINUTES, + ); } const text = typeof message === 'string' ? message : message.text; - const complete = typeof message === 'string' ? false : message.complete; + const complete = typeof message === 'string' ? false : message.complete ?? 
true; if (text === processedText) { noChangeCount++; diff --git a/api/server/services/Files/Audio/streamAudio.spec.js b/api/server/services/Files/Audio/streamAudio.spec.js index 501e252c14b..e76c0849c7f 100644 --- a/api/server/services/Files/Audio/streamAudio.spec.js +++ b/api/server/services/Files/Audio/streamAudio.spec.js @@ -3,6 +3,13 @@ const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio'); jest.mock('keyv'); const globalCache = {}; +jest.mock('~/models/Message', () => { + return { + getMessage: jest.fn().mockImplementation((messageId) => { + return globalCache[messageId] || null; + }), + }; +}); jest.mock('~/cache/getLogStores', () => { return jest.fn().mockImplementation(() => { const EventEmitter = require('events'); @@ -56,9 +63,10 @@ describe('processChunks', () => { jest.resetAllMocks(); mockMessageCache = { get: jest.fn(), + set: jest.fn(), }; require('~/cache/getLogStores').mockReturnValue(mockMessageCache); - processChunks = createChunkProcessor('message-id'); + processChunks = createChunkProcessor('userId', 'message-id'); }); it('should return an empty array when the message is not found', async () => { diff --git a/api/server/services/Runs/StreamRunManager.js b/api/server/services/Runs/StreamRunManager.js index ae00659983e..4bab7326bb0 100644 --- a/api/server/services/Runs/StreamRunManager.js +++ b/api/server/services/Runs/StreamRunManager.js @@ -1,19 +1,15 @@ -const throttle = require('lodash/throttle'); const { - Time, - CacheKeys, + Constants, StepTypes, ContentTypes, ToolCallTypes, MessageContentTypes, AssistantStreamEvents, - Constants, } = require('librechat-data-provider'); const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { processRequiredActions } = require('~/server/services/ToolService'); const { createOnProgress, sendMessage, sleep } = require('~/server/utils'); const { processMessages } = require('~/server/services/Threads'); -const { getLogStores } = require('~/cache'); const { 
logger } = require('~/config'); /** @@ -611,20 +607,8 @@ class StreamRunManager { const index = this.getStepIndex(stepKey); this.orderedRunSteps.set(index, message_creation); - const messageCache = getLogStores(CacheKeys.MESSAGES); - // Create the Factory Function to stream the message - const { onProgress: progressCallback } = createOnProgress({ - onProgress: throttle( - () => { - messageCache.set(this.finalMessage.messageId, this.getText(), Time.FIVE_MINUTES); - }, - 3000, - { trailing: false }, - ), - }); + const { onProgress: progressCallback } = createOnProgress(); - // This creates a function that attaches all of the parameters - // specified here to each SSE message generated by the TextStream const onProgress = progressCallback({ index, res: this.res, diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 92f8253fc73..9567944d1cc 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -18,7 +18,12 @@ const citationRegex = /\[\^\d+?\^]/g; const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text); const base = { message: true, initial: true }; -const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { +const createOnProgress = ( + { generation = '', onProgress: _onProgress } = { + generation: '', + onProgress: null, + }, +) => { let i = 0; let tokens = addSpaceIfNeeded(generation); diff --git a/client/public/assets/silence.mp3 b/client/public/assets/silence.mp3 new file mode 100644 index 00000000000..482395f021c Binary files /dev/null and b/client/public/assets/silence.mp3 differ diff --git a/client/src/App.jsx b/client/src/App.jsx index e2b11b261f9..38e568e4225 100644 --- a/client/src/App.jsx +++ b/client/src/App.jsx @@ -49,5 +49,14 @@ const App = () => { export default () => ( +