Merge branch 'main' into feature/chats-import
DenisPalnitsky authored Apr 17, 2024
2 parents 30e44a9 + 3c184e9 commit 8fdfb11
Showing 37 changed files with 1,044 additions and 262 deletions.
2 changes: 1 addition & 1 deletion .env.example
@@ -192,7 +192,7 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=

# Google
#-----------------
GOOGLE_API_KEY=
GOOGLE_SEARCH_API_KEY=
GOOGLE_CSE_ID=

# SerpAPI
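For reference, a minimal sketch of the resulting entries in a local .env after this rename (values are placeholders; the rename presumably keeps the Custom Search credential from clashing with the API key used by the Google chat endpoint):

```
# Google Search plugin — key renamed from GOOGLE_API_KEY
GOOGLE_SEARCH_API_KEY=your-custom-search-api-key
GOOGLE_CSE_ID=your-programmable-search-engine-id
```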
135 changes: 118 additions & 17 deletions api/app/clients/GoogleClient.js
@@ -1,7 +1,9 @@
const { google } = require('googleapis');
const { Agent, ProxyAgent } = require('undici');
const { GoogleVertexAI } = require('langchain/llms/googlevertexai');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { GoogleVertexAI } = require('@langchain/community/llms/googlevertexai');
const { ChatGoogleVertexAI } = require('langchain/chat_models/googlevertexai');
const { AIMessage, HumanMessage, SystemMessage } = require('langchain/schema');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
@@ -10,6 +12,7 @@ const {
getResponseSender,
endpointSettings,
EModelEndpoint,
VisionModes,
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
@@ -126,7 +129,7 @@ class GoogleClient extends BaseClient {

this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments));

// TODO: as of 12/14/23, only gemini models are "Generative AI" models provided by Google
/** @type {boolean} Whether using a "GenerativeAI" Model */
this.isGenerativeModel = this.modelOptions.model.includes('gemini');
const { isGenerativeModel } = this;
this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat');
@@ -247,6 +250,40 @@ class GoogleClient extends BaseClient {
})).bind(this);
}

/**
* Formats messages for generative AI
* @param {TMessage[]} messages
* @returns {Promise<Array>} The formatted messages
*/
async formatGenerativeMessages(messages) {
const formattedMessages = [];
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
const files = await this.addImageURLs(latestMessage, attachments, VisionModes.generative);
this.options.attachments = files;
messages[messages.length - 1] = latestMessage;

for (const _message of messages) {
const role = _message.isCreatedByUser ? this.userLabel : this.modelLabel;
const parts = [];
parts.push({ text: _message.text });
if (!_message.image_urls?.length) {
formattedMessages.push({ role, parts });
continue;
}

for (const images of _message.image_urls) {
if (images.inlineData) {
parts.push({ inlineData: images.inlineData });
}
}

formattedMessages.push({ role, parts });
}

return formattedMessages;
}

/**
*
* Adds image URLs to the message object and returns the files
@@ -255,17 +292,23 @@
* @param {MongoFile[]} files
* @returns {Promise<MongoFile[]>}
*/
async addImageURLs(message, attachments) {
async addImageURLs(message, attachments, mode = '') {
const { files, image_urls } = await encodeAndFormat(
this.options.req,
attachments,
EModelEndpoint.google,
mode,
);
message.image_urls = image_urls.length ? image_urls : undefined;
return files;
}

async buildVisionMessages(messages = [], parentMessageId) {
/**
* Builds the augmented prompt for attachments
* TODO: Add File API Support
* @param {TMessage[]} messages
*/
async buildAugmentedPrompt(messages = []) {
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
this.contextHandlers = createContextHandlers(this.options.req, latestMessage.text);
@@ -281,6 +324,12 @@
this.augmentedPrompt = await this.contextHandlers.createContext();
this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix;
}
}

async buildVisionMessages(messages = [], parentMessageId) {
const attachments = await this.options.attachments;
const latestMessage = { ...messages[messages.length - 1] };
await this.buildAugmentedPrompt(messages);

const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);

@@ -301,15 +350,26 @@
return { prompt: payload };
}

/** @param {TMessage[]} [messages=[]] */
async buildGenerativeMessages(messages = []) {
this.userLabel = 'user';
this.modelLabel = 'model';
const promises = [];
promises.push(await this.formatGenerativeMessages(messages));
promises.push(this.buildAugmentedPrompt(messages));
const [formattedMessages] = await Promise.all(promises);
return { prompt: formattedMessages };
}

async buildMessages(messages = [], parentMessageId) {
if (!this.isGenerativeModel && !this.project_id) {
throw new Error(
'[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)',
);
} else if (this.isGenerativeModel && (!this.apiKey || this.apiKey === 'user_provided')) {
throw new Error(
'[GoogleClient] an API Key is required for Gemini models (Generative Language API)',
);
}

if (!this.project_id && this.modelOptions.model.includes('1.5')) {
return await this.buildGenerativeMessages(messages);
}

if (this.options.attachments && this.isGenerativeModel) {
@@ -526,13 +586,24 @@
}

createLLM(clientOptions) {
if (this.isGenerativeModel) {
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
const model = clientOptions.modelName ?? clientOptions.model;
if (this.project_id && this.isTextModel) {
return new GoogleVertexAI(clientOptions);
} else if (this.project_id && this.isChatModel) {
return new ChatGoogleVertexAI(clientOptions);
} else if (this.project_id) {
return new ChatVertexAI(clientOptions);
} else if (model.includes('1.5')) {
return new GenAI(this.apiKey).getGenerativeModel(
{
...clientOptions,
model,
},
{ apiVersion: 'v1beta' },
);
}

return this.isTextModel
? new GoogleVertexAI(clientOptions)
: new ChatGoogleVertexAI(clientOptions);
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}

async getCompletion(_payload, options = {}) {
@@ -544,7 +615,7 @@

let clientOptions = { ...parameters, maxRetries: 2 };

if (!this.isGenerativeModel) {
if (this.project_id) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
@@ -557,7 +628,7 @@
clientOptions = { ...clientOptions, ...this.modelOptions };
}

if (this.isGenerativeModel) {
if (this.isGenerativeModel && !this.project_id) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}
@@ -588,16 +659,46 @@
messages.unshift(new SystemMessage(context));
}

const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
contents: _payload,
};

if (this.options?.promptPrefix?.length) {
requestOptions.systemInstruction = {
parts: [
{
text: this.options.promptPrefix,
},
],
};
}

const result = await client.generateContentStream(requestOptions);
for await (const chunk of result.stream) {
const chunkText = chunk.text();
this.generateTextStream(chunkText, onProgress, {
delay: 12,
});
reply += chunkText;
}
return reply;
}

const stream = await model.stream(messages, {
signal: abortController.signal,
timeout: 7000,
});

for await (const chunk of stream) {
await this.generateTextStream(chunk?.content ?? chunk, onProgress, {
const chunkText = chunk?.content ?? chunk;
this.generateTextStream(chunkText, onProgress, {
delay: this.isGenerativeModel ? 12 : 8,
});
reply += chunk?.content ?? chunk;
reply += chunkText;
}

return reply;
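Taken together, the new branch bypasses LangChain for Gemini 1.5 and streams through @google/generative-ai directly. A self-contained sketch of that path (model name, API key variable, and prompt are illustrative; the client above additionally wires in attachments, the promptPrefix as systemInstruction, and abort handling):

```js
const { GoogleGenerativeAI } = require('@google/generative-ai');

// Assumes an API key in GOOGLE_KEY (illustrative; not necessarily the
// variable LibreChat reads).
const genAI = new GoogleGenerativeAI(process.env.GOOGLE_KEY);
const model = genAI.getGenerativeModel(
  { model: 'gemini-1.5-pro-latest' },
  { apiVersion: 'v1beta' }, // the 1.5 path above also pins v1beta
);

async function streamReply(contents) {
  // generateContentStream resolves to { stream, response }; accumulating
  // chunk.text() mirrors the reply building in getCompletion above.
  const result = await model.generateContentStream({ contents });
  let reply = '';
  for await (const chunk of result.stream) {
    reply += chunk.text();
  }
  return reply;
}

streamReply([{ role: 'user', parts: [{ text: 'Say hello.' }] }])
  .then(console.log)
  .catch(console.error);
```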
2 changes: 1 addition & 1 deletion api/app/clients/prompts/index.js
@@ -13,7 +13,7 @@ module.exports = {
...handleInputs,
...instructions,
...titlePrompts,
truncateText,
...truncateText,
createVisionPrompt,
createContextHandlers,
};
38 changes: 34 additions & 4 deletions api/app/clients/prompts/truncateText.js
@@ -1,10 +1,40 @@
const MAX_CHAR = 255;

function truncateText(text) {
if (text.length > MAX_CHAR) {
return `${text.slice(0, MAX_CHAR)}... [text truncated for brevity]`;
/**
* Truncates a given text to a specified maximum length, appending ellipsis and a notification
* if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text if the original text length exceeds maxLength, otherwise returns the original text.
*/
function truncateText(text, maxLength = MAX_CHAR) {
if (text.length > maxLength) {
return `${text.slice(0, maxLength)}... [text truncated for brevity]`;
}
return text;
}

module.exports = truncateText;
/**
* Truncates a given text to a specified maximum length by showing the first half and the last half of the text,
* separated by ellipsis. This method ensures the output does not exceed the maximum length, including the addition
* of ellipsis and notification if the original text exceeds the maximum length.
*
* @param {string} text - The text to be truncated.
* @param {number} [maxLength=MAX_CHAR] - The maximum length of the output text after truncation. Defaults to MAX_CHAR.
* @returns {string} The truncated text showing the first half and the last half, or the original text if it does not exceed maxLength.
*/
function smartTruncateText(text, maxLength = MAX_CHAR) {
const ellipsis = '...';
const notification = ' [text truncated for brevity]';
const halfMaxLength = Math.floor((maxLength - ellipsis.length - notification.length) / 2);

if (text.length > maxLength) {
const startLastHalf = text.length - halfMaxLength;
return `${text.slice(0, halfMaxLength)}${ellipsis}${text.slice(startLastHalf)}${notification}`;
}

return text;
}

module.exports = { truncateText, smartTruncateText };
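A quick usage sketch of the two exports (example string and lengths are illustrative):

```js
const { truncateText, smartTruncateText } = require('./truncateText');

const text = 'The quick brown fox jumps over the lazy dog, again and again and again.';

// Head-only truncation: keeps the first maxLength characters, then appends
// the ellipsis and notice on top of the budget.
console.log(truncateText(text, 20));
// => "The quick brown fox ... [text truncated for brevity]"

// Smart truncation: keeps head and tail, and counts the ellipsis and the
// " [text truncated for brevity]" notice against maxLength, so the output
// stays at roughly maxLength characters total.
console.log(smartTruncateText(text, 50));
// => "The quick...nd again. [text truncated for brevity]"
```

Note the asymmetry: truncateText can exceed maxLength by the length of its suffix, while smartTruncateText budgets for it. The index.js change above switches from exporting the bare function to spreading the new object, so both helpers surface through the prompts barrel.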
2 changes: 1 addition & 1 deletion api/app/clients/tools/manifest.json
@@ -24,7 +24,7 @@
"description": "This is your Google Custom Search Engine ID. For instructions on how to obtain this, see <a href='https://github.com/danny-avila/LibreChat/blob/main/docs/features/plugins/google_search.md'>Our Docs</a>."
},
{
"authField": "GOOGLE_API_KEY",
"authField": "GOOGLE_SEARCH_API_KEY",
"label": "Google API Key",
"description": "This is your Google Custom Search API Key. For instructions on how to obtain this, see <a href='https://github.com/danny-avila/LibreChat/blob/main/docs/features/plugins/google_search.md'>Our Docs</a>."
}
Expand Down
2 changes: 1 addition & 1 deletion api/app/clients/tools/structured/GoogleSearch.js
@@ -9,7 +9,7 @@ class GoogleSearchResults extends Tool {

constructor(fields = {}) {
super(fields);
this.envVarApiKey = 'GOOGLE_API_KEY';
this.envVarApiKey = 'GOOGLE_SEARCH_API_KEY';
this.envVarSearchEngineId = 'GOOGLE_CSE_ID';
this.override = fields.override ?? false;
this.apiKey = fields.apiKey ?? getEnvironmentVariable(this.envVarApiKey);
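A small sketch of how the renamed variable resolves (import path and placeholder values are illustrative): explicitly passed fields win, with the environment as fallback.

```js
// Illustrative import path; the class lives in
// api/app/clients/tools/structured/GoogleSearch.js.
const GoogleSearchResults = require('./GoogleSearch');

process.env.GOOGLE_SEARCH_API_KEY = '<custom-search-key>'; // placeholder
process.env.GOOGLE_CSE_ID = '<engine-id>'; // placeholder

// No fields passed: the key resolves via
// getEnvironmentVariable('GOOGLE_SEARCH_API_KEY'), per the constructor above.
const fromEnv = new GoogleSearchResults();

// Explicit fields take precedence over the environment (fields.apiKey ?? env).
const explicit = new GoogleSearchResults({ apiKey: '<another-key>' });
```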
4 changes: 4 additions & 0 deletions api/models/tx.js
@@ -25,6 +25,10 @@ const tokenValues = {
/* cohere doesn't have rates for the older command models,
so this was from https://artificialanalysis.ai/models/command-light/providers */
command: { prompt: 0.38, completion: 0.38 },
// 'gemini-1.5': { prompt: 7, completion: 21 }, // May 2nd, 2024 pricing
// 'gemini': { prompt: 0.5, completion: 1.5 }, // May 2nd, 2024 pricing
'gemini-1.5': { prompt: 0, completion: 0 }, // currently free
gemini: { prompt: 0, completion: 0 }, // currently free
};

/**
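The commented lines park the announced pricing while both Gemini keys stay zeroed. A sketch of how these multipliers could translate into a request cost, assuming — as the neighboring entries suggest — rates are USD per 1M tokens:

```js
// Assumption: rates are USD per 1M tokens, matching the convention of the
// surrounding entries in tx.js.
const tokenValues = {
  'gemini-1.5': { prompt: 0, completion: 0 }, // currently free
  gemini: { prompt: 0, completion: 0 }, // currently free
};

function requestCost(model, promptTokens, completionTokens) {
  const { prompt, completion } = tokenValues[model];
  return (promptTokens * prompt + completionTokens * completion) / 1_000_000;
}

// 0 while the preview is free; would be 12000/1e6*7 + 800/1e6*21 ≈ $0.10
// under the commented rates.
console.log(requestCost('gemini-1.5', 12_000, 800));
```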
6 changes: 4 additions & 2 deletions api/package.json
@@ -35,10 +35,12 @@
"dependencies": {
"@anthropic-ai/sdk": "^0.16.1",
"@azure/search-documents": "^12.0.0",
"@google/generative-ai": "^0.5.0",
"@keyv/mongo": "^2.1.8",
"@keyv/redis": "^2.8.1",
"@langchain/community": "^0.0.17",
"@langchain/google-genai": "^0.0.8",
"@langchain/community": "^0.0.46",
"@langchain/google-genai": "^0.0.11",
"@langchain/google-vertexai": "^0.0.5",
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
6 changes: 3 additions & 3 deletions api/server/controllers/AskController.js
@@ -1,5 +1,5 @@
const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
@@ -48,7 +48,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {

try {
const { client } = await initializeClient({ req, res, endpointOption });

const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
Expand All @@ -59,7 +59,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: client.modelOptions.model,
unfinished: true,
unfinished,
error: false,
user,
});
5 changes: 3 additions & 2 deletions api/server/controllers/EditController.js
@@ -1,5 +1,5 @@
const throttle = require('lodash/throttle');
const { getResponseSender } = require('librechat-data-provider');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
@@ -48,6 +48,7 @@ const EditController = async (req, res, next, initializeClient) => {
}
};

const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
generation,
onProgress: throttle(
@@ -59,7 +60,7 @@
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: endpointOption.modelOptions.model,
unfinished: true,
unfinished,
isEdited: true,
error: false,
user,
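Both controllers now gate the flag on the endpoint. The ternary is equivalent to a plain comparison; a sketch:

```js
// Equivalent to `endpoint === EModelEndpoint.google ? false : true`:
// partial saves are marked unfinished for every endpoint except Google.
const unfinished = endpointOption.endpoint !== EModelEndpoint.google;
```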
(The remaining changed files in this commit are not rendered in this view.)
