From 514a502b9c3a649f7682b937313d181f8ba21e49 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 23 May 2024 16:27:36 -0400 Subject: [PATCH] =?UTF-8?q?=E2=8F=AF=EF=B8=8F=20fix(tts):=20Resolve=20Voic?= =?UTF-8?q?e=20Selection=20and=20Manual=20Playback=20Issues=20(#2845)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: voice setting for autoplayback TTS * fix(useTextToSpeechExternal): resolve stateful playback issues and consolidate state logic * refactor: initialize tts voice and provider schema once per request * fix(tts): edge case, longer text inputs. TODO: use continuous stream for longer text inputs * fix(tts): pause global audio on conversation change * refactor: keyvMongo ban cache to allow db updates for unbanning, to prevent server restart * chore: eslint fix * refactor: make ban cache exclusively keyvMongo --- api/server/middleware/checkBan.js | 6 +- .../services/Files/Audio/streamAudio.js | 61 ++++- .../services/Files/Audio/streamAudio.spec.js | 51 ++++- .../services/Files/Audio/textToSpeech.js | 210 ++++++++++-------- api/server/services/start/assistants.js | 1 - api/typedefs.js | 6 + .../src/components/Chat/Input/StreamAudio.tsx | 15 +- .../hooks/Input/useTextToSpeechExternal.ts | 83 +++---- client/src/localization/languages/Eng.ts | 3 + packages/data-provider/src/config.ts | 70 +++--- 10 files changed, 330 insertions(+), 176 deletions(-) diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js index aa322cd1c2e..ee7d7171de0 100644 --- a/api/server/middleware/checkBan.js +++ b/api/server/middleware/checkBan.js @@ -2,14 +2,12 @@ const Keyv = require('keyv'); const uap = require('ua-parser-js'); const { ViolationTypes } = require('librechat-data-provider'); const { isEnabled, removePorts } = require('../utils'); -const keyvRedis = require('~/cache/keyvRedis'); +const keyvMongo = require('~/cache/keyvMongo'); const denyRequest = require('./denyRequest'); const { getLogStores } = require('~/cache'); const User = require('~/models/User'); -const banCache = isEnabled(process.env.USE_REDIS) - ? new Keyv({ store: keyvRedis }) - : new Keyv({ namespace: ViolationTypes.BAN, ttl: 0 }); +const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 }); const message = 'Your account has been temporarily banned due to violations of our service.'; /** diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index 4b74816d28c..9f301e710bf 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -90,7 +90,7 @@ function findLastSeparatorIndex(text, separators = SEPARATORS) { } const MAX_NOT_FOUND_COUNT = 6; -const MAX_NO_CHANGE_COUNT = 12; +const MAX_NO_CHANGE_COUNT = 10; /** * @param {string} messageId @@ -152,6 +152,64 @@ function createChunkProcessor(messageId) { return processChunks; } +/** + * @param {string} text + * @param {number} [chunkSize=4000] + * @returns {{ text: string, isFinished: boolean }[]} + */ +function splitTextIntoChunks(text, chunkSize = 4000) { + if (!text) { + throw new Error('Text is required'); + } + + const chunks = []; + let startIndex = 0; + const textLength = text.length; + + while (startIndex < textLength) { + let endIndex = Math.min(startIndex + chunkSize, textLength); + let chunkText = text.slice(startIndex, endIndex); + + if (endIndex < textLength) { + let lastSeparatorIndex = -1; + for (const separator of SEPARATORS) { + const index = chunkText.lastIndexOf(separator); + if (index !== -1) { + lastSeparatorIndex = Math.max(lastSeparatorIndex, index); + } + } + + if (lastSeparatorIndex !== -1) { + endIndex = startIndex + lastSeparatorIndex + 1; + chunkText = text.slice(startIndex, endIndex); + } else { + const nextSeparatorIndex = text.slice(endIndex).search(/\S/); + if (nextSeparatorIndex !== -1) { + endIndex += nextSeparatorIndex; + chunkText = text.slice(startIndex, endIndex); + } + } + } + + chunkText = chunkText.trim(); + if (chunkText) { + chunks.push({ + text: chunkText, + isFinished: endIndex >= textLength, + }); + } else if (chunks.length > 0) { + chunks[chunks.length - 1].isFinished = true; + } + + startIndex = endIndex; + while (startIndex < textLength && text[startIndex].trim() === '') { + startIndex++; + } + } + + return chunks; +} + /** * Input stream text to speech * @param {Express.Response} res @@ -307,6 +365,7 @@ module.exports = { inputStreamTextToSpeech, findLastSeparatorIndex, createChunkProcessor, + splitTextIntoChunks, llmMessageSource, getRandomVoiceId, }; diff --git a/api/server/services/Files/Audio/streamAudio.spec.js b/api/server/services/Files/Audio/streamAudio.spec.js index 6aee27c7b81..7aff8dbfa76 100644 --- a/api/server/services/Files/Audio/streamAudio.spec.js +++ b/api/server/services/Files/Audio/streamAudio.spec.js @@ -1,5 +1,5 @@ +const { createChunkProcessor, splitTextIntoChunks } = require('./streamAudio'); const { Message } = require('~/models/Message'); -const { createChunkProcessor } = require('./streamAudio'); jest.mock('~/models/Message', () => ({ Message: { @@ -86,3 +86,52 @@ describe('processChunks', () => { expect(Message.findOne().lean).toHaveBeenCalledTimes(2); }); }); + +describe('splitTextIntoChunks', () => { + test('splits text into chunks of specified size with default separators', () => { + const text = 'This is a test. This is only a test! Make sure it works properly? Okay.'; + const chunkSize = 20; + const expectedChunks = [ + { text: 'This is a test.', isFinished: false }, + { text: 'This is only a test!', isFinished: false }, + { text: 'Make sure it works p', isFinished: false }, + { text: 'roperly? Okay.', isFinished: true }, + ]; + + const result = splitTextIntoChunks(text, chunkSize); + expect(result).toEqual(expectedChunks); + }); + + test('splits text into chunks with default size', () => { + const text = 'A'.repeat(8000) + '. The end.'; + const expectedChunks = [ + { text: 'A'.repeat(4000), isFinished: false }, + { text: 'A'.repeat(4000), isFinished: false }, + { text: '. The end.', isFinished: true }, + ]; + + const result = splitTextIntoChunks(text); + expect(result).toEqual(expectedChunks); + }); + + test('returns a single chunk if text length is less than chunk size', () => { + const text = 'Short text.'; + const expectedChunks = [{ text: 'Short text.', isFinished: true }]; + + const result = splitTextIntoChunks(text, 4000); + expect(result).toEqual(expectedChunks); + }); + + test('handles text with no separators correctly', () => { + const text = 'ThisTextHasNoSeparatorsAndIsVeryLong'.repeat(100); + const chunkSize = 4000; + const expectedChunks = [{ text: text, isFinished: true }]; + + const result = splitTextIntoChunks(text, chunkSize); + expect(result).toEqual(expectedChunks); + }); + + test('throws an error when text is empty', () => { + expect(() => splitTextIntoChunks('')).toThrow('Text is required'); + }); +}); diff --git a/api/server/services/Files/Audio/textToSpeech.js b/api/server/services/Files/Audio/textToSpeech.js index 6c8f306c89a..2d77324ce4a 100644 --- a/api/server/services/Files/Audio/textToSpeech.js +++ b/api/server/services/Files/Audio/textToSpeech.js @@ -1,6 +1,6 @@ const axios = require('axios'); const getCustomConfig = require('~/server/services/Config/getCustomConfig'); -const { getRandomVoiceId, createChunkProcessor } = require('./streamAudio'); +const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio'); const { extractEnvVariable } = require('librechat-data-provider'); const { logger } = require('~/config'); @@ -54,7 +54,7 @@ function removeUndefined(obj) { * This function prepares the necessary data and headers for making a request to the OpenAI TTS * It uses the provided TTS schema, input text, and voice to create the request * - * @param {Object} ttsSchema - The TTS schema containing the OpenAI configuration + * @param {TCustomConfig['tts']['openai']} ttsSchema - The TTS schema containing the OpenAI configuration * @param {string} input - The text to be converted to speech * @param {string} voice - The voice to be used for the speech * @@ -62,27 +62,27 @@ function removeUndefined(obj) { * If an error occurs, it throws an error with a message indicating that the selected voice is not available */ function openAIProvider(ttsSchema, input, voice) { - const url = ttsSchema.openai?.url || 'https://api.openai.com/v1/audio/speech'; + const url = ttsSchema?.url || 'https://api.openai.com/v1/audio/speech'; if ( - ttsSchema.openai?.voices && - ttsSchema.openai.voices.length > 0 && - !ttsSchema.openai.voices.includes(voice) && - !ttsSchema.openai.voices.includes('ALL') + ttsSchema?.voices && + ttsSchema.voices.length > 0 && + !ttsSchema.voices.includes(voice) && + !ttsSchema.voices.includes('ALL') ) { throw new Error(`Voice ${voice} is not available.`); } let data = { input, - model: ttsSchema.openai?.model, - voice: ttsSchema.openai?.voices && ttsSchema.openai.voices.length > 0 ? voice : undefined, - backend: ttsSchema.openai?.backend, + model: ttsSchema?.model, + voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined, + backend: ttsSchema?.backend, }; let headers = { 'Content-Type': 'application/json', - Authorization: 'Bearer ' + extractEnvVariable(ttsSchema.openai?.apiKey), + Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey), }; [data, headers].forEach(removeUndefined); @@ -95,7 +95,7 @@ function openAIProvider(ttsSchema, input, voice) { * This function prepares the necessary data and headers for making a request to the Eleven Labs TTS * It uses the provided TTS schema, input text, and voice to create the request * - * @param {Object} ttsSchema - The TTS schema containing the Eleven Labs configuration + * @param {TCustomConfig['tts']['elevenLabs']} ttsSchema - The TTS schema containing the Eleven Labs configuration * @param {string} input - The text to be converted to speech * @param {string} voice - The voice to be used for the speech * @param {boolean} stream - Whether to stream the audio or not @@ -105,34 +105,31 @@ function openAIProvider(ttsSchema, input, voice) { */ function elevenLabsProvider(ttsSchema, input, voice, stream) { let url = - ttsSchema.elevenlabs?.url || + ttsSchema?.url || `https://api.elevenlabs.io/v1/text-to-speech/{voice_id}${stream ? '/stream' : ''}`; - if ( - !ttsSchema.elevenlabs?.voices.includes(voice) && - !ttsSchema.elevenlabs?.voices.includes('ALL') - ) { + if (!ttsSchema?.voices.includes(voice) && !ttsSchema?.voices.includes('ALL')) { throw new Error(`Voice ${voice} is not available.`); } url = url.replace('{voice_id}', voice); let data = { - model_id: ttsSchema.elevenlabs?.model, + model_id: ttsSchema?.model, text: input, // voice_id: voice, voice_settings: { - similarity_boost: ttsSchema.elevenlabs?.voice_settings?.similarity_boost, - stability: ttsSchema.elevenlabs?.voice_settings?.stability, - style: ttsSchema.elevenlabs?.voice_settings?.style, - use_speaker_boost: ttsSchema.elevenlabs?.voice_settings?.use_speaker_boost || undefined, + similarity_boost: ttsSchema?.voice_settings?.similarity_boost, + stability: ttsSchema?.voice_settings?.stability, + style: ttsSchema?.voice_settings?.style, + use_speaker_boost: ttsSchema?.voice_settings?.use_speaker_boost || undefined, }, - pronunciation_dictionary_locators: ttsSchema.elevenlabs?.pronunciation_dictionary_locators, + pronunciation_dictionary_locators: ttsSchema?.pronunciation_dictionary_locators, }; let headers = { 'Content-Type': 'application/json', - 'xi-api-key': extractEnvVariable(ttsSchema.elevenlabs?.apiKey), + 'xi-api-key': extractEnvVariable(ttsSchema?.apiKey), Accept: 'audio/mpeg', }; @@ -146,7 +143,7 @@ function elevenLabsProvider(ttsSchema, input, voice, stream) { * This function prepares the necessary data and headers for making a request to the LocalAI TTS * It uses the provided TTS schema, input text, and voice to create the request * - * @param {Object} ttsSchema - The TTS schema containing the LocalAI configuration + * @param {TCustomConfig['tts']['localai']} ttsSchema - The TTS schema containing the LocalAI configuration * @param {string} input - The text to be converted to speech * @param {string} voice - The voice to be used for the speech * @@ -154,102 +151,78 @@ function elevenLabsProvider(ttsSchema, input, voice, stream) { * @throws {Error} Throws an error if the selected voice is not available */ function localAIProvider(ttsSchema, input, voice) { - let url = ttsSchema.localai?.url; + let url = ttsSchema?.url; if ( - ttsSchema.localai?.voices && - ttsSchema.localai.voices.length > 0 && - !ttsSchema.localai.voices.includes(voice) && - !ttsSchema.localai.voices.includes('ALL') + ttsSchema?.voices && + ttsSchema.voices.length > 0 && + !ttsSchema.voices.includes(voice) && + !ttsSchema.voices.includes('ALL') ) { throw new Error(`Voice ${voice} is not available.`); } let data = { input, - model: ttsSchema.localai?.voices && ttsSchema.localai.voices.length > 0 ? voice : undefined, - backend: ttsSchema.localai?.backend, + model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined, + backend: ttsSchema?.backend, }; let headers = { 'Content-Type': 'application/json', - Authorization: 'Bearer ' + extractEnvVariable(ttsSchema.localai?.apiKey), + Authorization: 'Bearer ' + extractEnvVariable(ttsSchema?.apiKey), }; [data, headers].forEach(removeUndefined); - if (extractEnvVariable(ttsSchema.localai.apiKey) === '') { + if (extractEnvVariable(ttsSchema.apiKey) === '') { delete headers.Authorization; } return [url, data, headers]; } -/* not used */ -/* -async function streamAudioFromWebSocket(req, res) { - const { voice } = req.body; - const customConfig = await getCustomConfig(); - - if (!customConfig) { - return res.status(500).send('Custom config not found'); - } - - const ttsSchema = customConfig.tts; - const provider = getProvider(ttsSchema); - - if (provider !== 'elevenlabs') { - return res.status(400).send('WebSocket streaming is only supported for Eleven Labs'); - } - - const url = - ttsSchema.elevenlabs.websocketUrl || - 'wss://api.elevenlabs.io/v1/text-to-speech/{voice_id}/stream-input?model_id={model}' - .replace('{voice_id}', voice) - .replace('{model}', ttsSchema.elevenlabs.model); - const ws = new WebSocket(url); - - ws.onopen = () => { - logger.debug('WebSocket connection opened'); - sendTextToWebsocket(ws, (data) => { - res.write(data); // Stream data directly to the response - }); - }; - - ws.onclose = () => { - logger.debug('WebSocket connection closed'); - res.end(); // End the response when the WebSocket is closed - }; - - ws.onerror = (error) => { - logger.error('WebSocket error:', error); - res.status(500).send('WebSocket error'); - }; +/** + * + * Returns provider and its schema for use with TTS requests + * @param {TCustomConfig} customConfig + * @param {string} _voice + * @returns {Promise<[string, TProviderSchema]>} + */ +async function getProviderSchema(customConfig) { + const provider = getProvider(customConfig.tts); + return [provider, customConfig.tts[provider]]; } -*/ /** * - * @param {TCustomConfig} customConfig - * @param {string} voice - * @returns {Promise} + * Returns a tuple of the TTS schema as well as the voice for the TTS request + * @param {TProviderSchema} providerSchema + * @param {string} requestVoice + * @returns {Promise} */ -async function ttsRequest( - customConfig, - { input, voice: _v, stream = true } = { input: '', stream: true }, -) { - const ttsSchema = customConfig.tts; - const provider = getProvider(ttsSchema); - const voices = ttsSchema[provider].voices.filter( - (voice) => voice && voice.toUpperCase() !== 'ALL', - ); - let voice = _v; +async function getVoice(providerSchema, requestVoice) { + const voices = providerSchema.voices.filter((voice) => voice && voice.toUpperCase() !== 'ALL'); + let voice = requestVoice; if (!voice || !voices.includes(voice) || (voice.toUpperCase() === 'ALL' && voices.length > 1)) { voice = getRandomVoiceId(voices); } - let [url, data, headers] = []; + return voice; +} +/** + * + * @param {string} provider + * @param {TProviderSchema} ttsSchema + * @param {object} params + * @param {string} params.voice + * @param {string} params.input + * @param {boolean} [params.stream] + * @returns {Promise} + */ +async function ttsRequest(provider, ttsSchema, { input, voice, stream = true } = { stream: true }) { + let [url, data, headers] = []; switch (provider) { case 'openai': [url, data, headers] = openAIProvider(ttsSchema, input, voice); @@ -283,7 +256,7 @@ async function ttsRequest( * @throws {Error} Throws an error if the provider is invalid */ async function textToSpeech(req, res) { - const { input, voice } = req.body; + const { input } = req.body; if (!input) { return res.status(400).send('Missing text in request body'); @@ -296,8 +269,47 @@ async function textToSpeech(req, res) { try { res.setHeader('Content-Type', 'audio/mpeg'); - const response = await ttsRequest(customConfig, { input, voice }); - response.data.pipe(res); + const [provider, ttsSchema] = await getProviderSchema(customConfig); + const voice = await getVoice(ttsSchema, req.body.voice); + if (input.length < 4096) { + const response = await ttsRequest(provider, ttsSchema, { input, voice }); + response.data.pipe(res); + return; + } + + const textChunks = splitTextIntoChunks(input, 1000); + + for (const chunk of textChunks) { + try { + const response = await ttsRequest(provider, ttsSchema, { + voice, + input: chunk.text, + stream: true, + }); + + logger.debug(`[textToSpeech] user: ${req?.user?.id} | writing audio stream`); + await new Promise((resolve) => { + response.data.pipe(res, { end: chunk.isFinished }); + response.data.on('end', () => { + resolve(); + }); + }); + + if (chunk.isFinished) { + break; + } + } catch (innerError) { + logger.error('Error processing update:', chunk, innerError); + if (!res.headersSent) { + res.status(500).end(); + } + return; + } + } + + if (!res.headersSent) { + res.end(); + } } catch (error) { logger.error('An error occurred while creating the audio stream:', error); res.status(500).send('An error occurred'); @@ -311,8 +323,17 @@ async function streamAudio(req, res) { return res.status(500).send('Custom config not found'); } + const [provider, ttsSchema] = await getProviderSchema(customConfig); + const voice = await getVoice(ttsSchema, req.body.voice); + try { let shouldContinue = true; + + req.on('close', () => { + logger.warn('[streamAudio] Audio Stream Request closed by client'); + shouldContinue = false; + }); + const processChunks = createChunkProcessor(req.body.messageId); while (shouldContinue) { @@ -337,7 +358,8 @@ async function streamAudio(req, res) { for (const update of updates) { try { - const response = await ttsRequest(customConfig, { + const response = await ttsRequest(provider, ttsSchema, { + voice, input: update.text, stream: true, }); @@ -348,7 +370,7 @@ async function streamAudio(req, res) { logger.debug(`[streamAudio] user: ${req?.user?.id} | writing audio stream`); await new Promise((resolve) => { - response.data.pipe(res, { end: false }); + response.data.pipe(res, { end: update.isFinished }); response.data.on('end', () => { resolve(); }); diff --git a/api/server/services/start/assistants.js b/api/server/services/start/assistants.js index eba6319f84c..aa43f2b7c8b 100644 --- a/api/server/services/start/assistants.js +++ b/api/server/services/start/assistants.js @@ -1,6 +1,5 @@ const { Capabilities, - EModelEndpoint, assistantEndpointSchema, defaultAssistantsVersion, } = require('librechat-data-provider'); diff --git a/api/typedefs.js b/api/typedefs.js index 5c83cab1598..b55dfa120a8 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -349,6 +349,12 @@ * @memberof typedefs */ +/** + * @exports TProviderSchema + * @typedef {import('librechat-data-provider').TProviderSchema} TProviderSchema + * @memberof typedefs + */ + /** * @exports TEndpoint * @typedef {import('librechat-data-provider').TEndpoint} TEndpoint diff --git a/client/src/components/Chat/Input/StreamAudio.tsx b/client/src/components/Chat/Input/StreamAudio.tsx index 34bd4a30397..b89c295d375 100644 --- a/client/src/components/Chat/Input/StreamAudio.tsx +++ b/client/src/components/Chat/Input/StreamAudio.tsx @@ -1,10 +1,10 @@ import { useParams } from 'react-router-dom'; +import { useEffect, useCallback } from 'react'; import { QueryKeys } from 'librechat-data-provider'; import { useQueryClient } from '@tanstack/react-query'; -import { useEffect, useCallback } from 'react'; import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil'; import type { TMessage } from 'librechat-data-provider'; -import { useCustomAudioRef, MediaSourceAppender } from '~/hooks/Audio'; +import { useCustomAudioRef, MediaSourceAppender, usePauseGlobalAudio } from '~/hooks/Audio'; import { useAuthContext } from '~/hooks'; import { globalAudioId } from '~/common'; import store from '~/store'; @@ -24,6 +24,7 @@ export default function StreamAudio({ index = 0 }) { const cacheTTS = useRecoilValue(store.cacheTTS); const playbackRate = useRecoilValue(store.playbackRate); + const voice = useRecoilValue(store.voice); const activeRunId = useRecoilValue(store.activeRunFamily(index)); const automaticPlayback = useRecoilValue(store.automaticPlayback); const isSubmitting = useRecoilValue(store.isSubmittingFamily(index)); @@ -34,6 +35,7 @@ export default function StreamAudio({ index = 0 }) { const [globalAudioURL, setGlobalAudioURL] = useRecoilState(store.globalAudioURLFamily(index)); const { audioRef } = useCustomAudioRef({ setIsPlaying }); + const { pauseGlobalAudio } = usePauseGlobalAudio(); const { conversationId: paramId } = useParams(); const queryParam = paramId === 'new' ? paramId : latestMessage?.conversationId ?? paramId ?? ''; @@ -90,7 +92,7 @@ export default function StreamAudio({ index = 0 }) { const response = await fetch('/api/files/tts', { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, - body: JSON.stringify({ messageId: latestMessage?.messageId, runId: activeRunId }), + body: JSON.stringify({ messageId: latestMessage?.messageId, runId: activeRunId, voice }), }); if (!response.ok) { @@ -166,6 +168,7 @@ export default function StreamAudio({ index = 0 }) { audioRunId, cacheTTS, audioRef, + voice, token, ]); @@ -180,6 +183,12 @@ export default function StreamAudio({ index = 0 }) { } }, [audioRef, globalAudioURL, playbackRate]); + useEffect(() => { + pauseGlobalAudio(); + // We only want the effect to run when the paramId changes + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [paramId]); + return (