🏄‍♂️ refactor: Optimize Reasoning UI & Token Streaming #5546

Merged 23 commits on Jan 30, 2025

Commits
29ae39b
✨ feat: Implement Show Thinking feature; refactor: testing thinking r…
berry-13 Jan 27, 2025
8c240e7
✨ feat: Refactor Thinking component styles and enhance Markdown rende…
berry-13 Jan 27, 2025
e596854
chore: add back removed code, revert type changes
danny-avila Jan 28, 2025
d07cb3e
chore: Add back resetCounter effect to Markdown component for improve…
danny-avila Jan 28, 2025
e92f0d3
chore: bump @librechat/agents and google langchain packages
danny-avila Jan 29, 2025
0ee8e74
WIP: reasoning type updates
danny-avila Jan 29, 2025
51d8378
WIP: first pass, reasoning content blocks
danny-avila Jan 29, 2025
5e59612
chore: revert code
danny-avila Jan 29, 2025
f380ec6
chore: bump @librechat/agents
danny-avila Jan 29, 2025
160a455
refactor: optimize reasoning tag handling
danny-avila Jan 29, 2025
74422f5
style: ul indent padding
danny-avila Jan 29, 2025
e23fd43
feat: add Reasoning component to handle reasoning display
danny-avila Jan 29, 2025
0f5b68a
feat: first pass, content reasoning part styling
danny-avila Jan 29, 2025
910f7ae
refactor: add content placeholder for endpoints using new stream handler
danny-avila Jan 29, 2025
77410de
refactor: only cache messages when requesting stream audio
danny-avila Jan 29, 2025
bc12b78
fix: circular dep.
danny-avila Jan 29, 2025
c091c3a
fix: add default param
danny-avila Jan 29, 2025
24413fa
refactor: tts, only request after message stream, fix chrome autoplay
danny-avila Jan 29, 2025
0827475
style: update label for submitting state and add localization for 'Th…
danny-avila Jan 29, 2025
5e9a266
fix: improve global audio pause logic and reset active run ID
danny-avila Jan 29, 2025
2a66e57
fix: handle artifact edge cases
danny-avila Jan 29, 2025
bdad35b
fix: remove unnecessary console log from artifact update test
danny-avila Jan 29, 2025
704363f
feat: add support for continued message handling with new streaming m…
danny-avila Jan 30, 2025

Files changed

21 changes: 7 additions & 14 deletions api/app/clients/BaseClient.js
@@ -7,15 +7,12 @@ const {
EModelEndpoint,
ErrorTypes,
Constants,
CacheKeys,
Time,
} = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const { truncateToolCallOutputs } = require('./prompts');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
const { getLogStores } = require('~/cache');
const TextStream = require('./TextStream');
const { logger } = require('~/config');

@@ -54,6 +51,12 @@ class BaseClient {
this.outputTokensKey = 'completion_tokens';
/** @type {Set<string>} */
this.savedMessageIds = new Set();
/**
* Flag to determine if the client re-submitted the latest assistant message.
* @type {boolean | undefined} */
this.continued;
/** @type {TMessage[]} */
this.currentMessages = [];
}

setOptions() {
@@ -589,6 +592,7 @@
} else {
latestMessage.text = generation;
}
this.continued = true;
} else {
this.currentMessages.push(userMessage);
}
@@ -720,17 +724,6 @@

this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
this.savedMessageIds.add(responseMessage.messageId);
if (responseMessage.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
}
delete responseMessage.tokenCount;
return responseMessage;
}
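The BaseClient changes above drop the five-minute message-cache write after each response and instead track whether the latest assistant message was re-submitted for continuation, via the new continued flag and a typed currentMessages default. A minimal sketch of that pattern, using illustrative names rather than the actual LibreChat API:

// Sketch: detecting a "continued" generation (illustrative, not LibreChat's API).
class Client {
  constructor() {
    /** @type {boolean | undefined} set when the latest assistant message is re-submitted */
    this.continued = undefined;
    /** @type {Array<{ role: string, text: string }>} */
    this.currentMessages = [];
  }

  handleTurn(userMessage, { isContinued = false, generation = '' } = {}) {
    const latest = this.currentMessages[this.currentMessages.length - 1];
    if (isContinued && latest?.role === 'assistant') {
      // Re-submit the cut-off assistant text instead of adding a new user turn.
      latest.text = generation;
      this.continued = true;
    } else {
      this.currentMessages.push(userMessage);
    }
  }
}

const client = new Client();
client.handleTurn({ role: 'user', text: 'Explain SSE.' });
client.currentMessages.push({ role: 'assistant', text: 'SSE is…' });
client.handleTurn(null, { isContinued: true, generation: 'SSE is…' });
console.log(client.continued); // true

Downstream, the OpenAIClient diff below checks this flag and replays the re-submitted text through the new stream handler before consuming fresh chunks.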
116 changes: 81 additions & 35 deletions api/app/clients/OpenAIClient.js
@@ -1,6 +1,7 @@
const OpenAI = require('openai');
const { OllamaClient } = require('./OllamaClient');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { SplitStreamHandler, GraphEvents } = require('@librechat/agents');
const {
Constants,
ImageDetail,
@@ -28,17 +29,17 @@ const {
createContextHandlers,
} = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
const { isEnabled, sleep } = require('~/server/utils');
const { handleOpenAIErrors } = require('./tools/util');
const { createLLM, RunManager } = require('./llm');
const { logger, sendEvent } = require('~/config');
const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');

class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) {
@@ -65,6 +66,8 @@
this.usage;
/** @type {boolean|undefined} */
this.isO1Model;
/** @type {SplitStreamHandler | undefined} */
this.streamHandler;
}

// TODO: PluginsClient calls this 3x, unneeded
@@ -1064,11 +1067,36 @@ ${convo}
});
}

getStreamText() {
if (!this.streamHandler) {
return '';
}

const reasoningTokens =
this.streamHandler.reasoningTokens.length > 0
? `:::thinking\n${this.streamHandler.reasoningTokens.join('')}\n:::\n`
: '';

return `${reasoningTokens}${this.streamHandler.tokens.join('')}`;
}

getMessageMapMethod() {
/**
* @param {TMessage} msg
*/
return (msg) => {
if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
}

return msg;
};
}

async chatCompletion({ payload, onProgress, abortController = null }) {
let error = null;
let intermediateReply = [];
const errorCallback = (err) => (error = err);
const intermediateReply = [];
const reasoningTokens = [];
try {
if (!abortController) {
abortController = new AbortController();
@@ -1266,6 +1294,19 @@
reasoningKey = 'reasoning';
}

this.streamHandler = new SplitStreamHandler({
reasoningKey,
accumulate: true,
runId: this.responseMessageId,
handlers: {
[GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event),
},
});

intermediateReply = this.streamHandler.tokens;

if (modelOptions.stream) {
streamPromise = new Promise((resolve) => {
streamResolve = resolve;
@@ -1292,41 +1333,36 @@
}

if (typeof finalMessage.content !== 'string' || finalMessage.content.trim() === '') {
finalChatCompletion.choices[0].message.content = intermediateReply.join('');
finalChatCompletion.choices[0].message.content = this.streamHandler.tokens.join('');
}
})
.on('finalMessage', (message) => {
if (message?.role !== 'assistant') {
stream.messages.push({ role: 'assistant', content: intermediateReply.join('') });
stream.messages.push({
role: 'assistant',
content: this.streamHandler.tokens.join(''),
});
UnexpectedRoleError = true;
}
});

let reasoningCompleted = false;
for await (const chunk of stream) {
if (chunk?.choices?.[0]?.delta?.[reasoningKey]) {
if (reasoningTokens.length === 0) {
const thinkingDirective = '<think>\n';
intermediateReply.push(thinkingDirective);
reasoningTokens.push(thinkingDirective);
onProgress(thinkingDirective);
}
const reasoning_content = chunk?.choices?.[0]?.delta?.[reasoningKey] || '';
intermediateReply.push(reasoning_content);
reasoningTokens.push(reasoning_content);
onProgress(reasoning_content);
}

const token = chunk?.choices?.[0]?.delta?.content || '';
if (!reasoningCompleted && reasoningTokens.length > 0 && token) {
reasoningCompleted = true;
const separatorTokens = '\n</think>\n';
reasoningTokens.push(separatorTokens);
onProgress(separatorTokens);
}
if (this.continued === true) {
const latestText = addSpaceIfNeeded(
this.currentMessages[this.currentMessages.length - 1]?.text ?? '',
);
this.streamHandler.handle({
choices: [
{
delta: {
content: latestText,
},
},
],
});
}

intermediateReply.push(token);
onProgress(token);
for await (const chunk of stream) {
this.streamHandler.handle(chunk);
if (abortController.signal.aborted) {
stream.controller.abort();
break;
@@ -1369,7 +1405,7 @@

if (!Array.isArray(choices) || choices.length === 0) {
logger.warn('[OpenAIClient] Chat completion response has no choices');
return intermediateReply.join('');
return this.streamHandler.tokens.join('');
}

const { message, finish_reason } = choices[0] ?? {};
@@ -1379,20 +1415,30 @@

if (!message) {
logger.warn('[OpenAIClient] Message is undefined in chatCompletion response');
return intermediateReply.join('');
return this.streamHandler.tokens.join('');
}

if (typeof message.content !== 'string' || message.content.trim() === '') {
const reply = intermediateReply.join('');
const reply = this.streamHandler.tokens.join('');
logger.debug(
'[OpenAIClient] chatCompletion: using intermediateReply due to empty message.content',
{ intermediateReply: reply },
);
return reply;
}

if (reasoningTokens.length > 0 && this.options.context !== 'title') {
return reasoningTokens.join('') + message.content;
if (
this.streamHandler.reasoningTokens.length > 0 &&
this.options.context !== 'title' &&
!message.content.startsWith('<think>')
) {
return this.getStreamText();
} else if (
this.streamHandler.reasoningTokens.length > 0 &&
this.options.context !== 'title' &&
message.content.startsWith('<think>')
) {
return message.content.replace('<think>', ':::thinking').replace('</think>', ':::');
}

return message.content;
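The core of the OpenAIClient refactor above is SplitStreamHandler from @librechat/agents, which replaces the hand-rolled <think> tag bookkeeping: deltas carrying the reasoning key accumulate separately from content tokens, and getStreamText() rebuilds the final text under a :::thinking directive. A rough, self-contained sketch of that splitting idea — not the actual @librechat/agents implementation:

// Sketch of the split-stream idea: reasoning deltas and content deltas
// go into separate buffers, and the final text is reassembled with a
// :::thinking directive, mirroring getStreamText() above.
class MiniSplitHandler {
  constructor({ reasoningKey = 'reasoning_content' } = {}) {
    this.reasoningKey = reasoningKey;
    this.tokens = [];
    this.reasoningTokens = [];
  }

  handle(chunk) {
    const delta = chunk?.choices?.[0]?.delta ?? {};
    if (delta[this.reasoningKey]) {
      this.reasoningTokens.push(delta[this.reasoningKey]);
    }
    if (delta.content) {
      this.tokens.push(delta.content);
    }
  }

  getText() {
    const reasoning = this.reasoningTokens.length
      ? `:::thinking\n${this.reasoningTokens.join('')}\n:::\n`
      : '';
    return `${reasoning}${this.tokens.join('')}`;
  }
}

const handler = new MiniSplitHandler();
handler.handle({ choices: [{ delta: { reasoning_content: 'Let me check.' } }] });
handler.handle({ choices: [{ delta: { content: 'The answer is 42.' } }] });
console.log(handler.getText());
// :::thinking
// Let me check.
// :::
// The answer is 42.

The getMessageMapMethod() addition is the inverse operation: it strips a leading :::thinking block from message text before the message is mapped into the payload, so reasoning is displayed but never re-sent to the model.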
13 changes: 0 additions & 13 deletions api/app/clients/PluginsClient.js
@@ -1,5 +1,4 @@
const OpenAIClient = require('./OpenAIClient');
const { CacheKeys, Time } = require('librechat-data-provider');
const { CallbackManager } = require('@langchain/core/callbacks/manager');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
@@ -11,7 +10,6 @@ const checkBalance = require('~/models/checkBalance');
const { isEnabled } = require('~/server/utils');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');

class PluginsClient extends OpenAIClient {
@@ -256,17 +254,6 @@ class PluginsClient extends OpenAIClient {
}

this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
if (responseMessage.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessage.messageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
}
delete responseMessage.tokenCount;
return { ...responseMessage, ...result };
}
15 changes: 15 additions & 0 deletions api/config/index.js
@@ -16,7 +16,22 @@ async function getMCPManager() {
return mcpManager;
}

/**
* Sends message data in Server Sent Events format.
* @param {ServerResponse} res - The server response.
* @param {{ data: string | Record<string, unknown>, event?: string }} event - The message event.
* @param {string} event.event - The type of event.
* @param {string} event.data - The message to be sent.
*/
const sendEvent = (res, event) => {
if (typeof event.data === 'string' && event.data.length === 0) {
return;
}
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
};

module.exports = {
logger,
sendEvent,
getMCPManager,
};
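The new sendEvent helper frames each payload as a Server-Sent Event under the fixed event name "message", skipping empty string data. A minimal usage sketch — the event name, port, and handler are illustrative, and the SSE headers are standard boilerplate rather than part of this diff:

const http = require('http');

const sendEvent = (res, event) => {
  if (typeof event.data === 'string' && event.data.length === 0) {
    return;
  }
  res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
};

http
  .createServer((req, res) => {
    // SSE requires these headers before any writes.
    res.writeHead(200, {
      'Content-Type': 'text/event-stream',
      'Cache-Control': 'no-cache',
      Connection: 'keep-alive',
    });
    // Wire format: event: message\ndata: {"event":"on_message_delta","data":{"delta":"Hi"}}\n\n
    sendEvent(res, { event: 'on_message_delta', data: { delta: 'Hi' } });
    res.end();
  })
  .listen(3080);

This is what the SplitStreamHandler callbacks in OpenAIClient bind to: each ON_RUN_STEP, ON_MESSAGE_DELTA, and ON_REASONING_DELTA event is forwarded straight to the response stream.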
6 changes: 3 additions & 3 deletions api/package.json
@@ -41,10 +41,10 @@
"@keyv/redis": "^2.8.1",
"@langchain/community": "^0.3.14",
"@langchain/core": "^0.3.18",
"@langchain/google-genai": "^0.1.6",
"@langchain/google-vertexai": "^0.1.6",
"@langchain/google-genai": "^0.1.7",
"@langchain/google-vertexai": "^0.1.8",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^1.9.94",
"@librechat/agents": "^1.9.97",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.7.7",
"bcryptjs": "^2.4.3",
34 changes: 4 additions & 30 deletions api/server/controllers/AskController.js
@@ -1,8 +1,6 @@
const throttle = require('lodash/throttle');
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');

@@ -57,41 +55,17 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {

try {
const { client } = await initializeClient({ req, res, endpointOption });
const messageCache = getLogStores(CacheKeys.MESSAGES);
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
/*
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
messageCache.set(responseMessageId, {
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: client.modelOptions.model,
unfinished,
error: false,
user,
}, Time.FIVE_MINUTES);
*/

messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
},
3000,
{ trailing: false },
),
});
const { onProgress: progressCallback, getPartialText } = createOnProgress();

getText = getPartialText;
getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText;

const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
text: getText(),
userMessage,
promptTokens,
});
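With streamed text now accumulated inside the client's handler, AskController no longer throttles partial text into the message cache; it prefers the client's own getStreamText when present and falls back to the progress callback otherwise. A small sketch of that fallback binding, with a stub client standing in for the real one:

// Prefer the client's accumulated stream text; fall back to the
// progress callback's partial text for clients without a handler.
function resolveGetText(client, getPartialText) {
  return typeof client.getStreamText === 'function'
    ? client.getStreamText.bind(client)
    : getPartialText;
}

// Usage with a stub client that has its own stream handler:
const client = {
  streamText: ':::thinking\nChecking…\n:::\nDone.',
  getStreamText() {
    return this.streamText;
  },
};
const getText = resolveGetText(client, () => 'partial text');
console.log(getText()); // ':::thinking\nChecking…\n:::\nDone.'

Binding is needed because getStreamText reads this.streamHandler; passing the unbound method would lose the client context in the getAbortData callback.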