From cd4a807f0b742012acc26b6de8c44670eea86ef9 Mon Sep 17 00:00:00 2001
From: David Duong
Date: Mon, 31 Jul 2023 21:36:18 +0200
Subject: [PATCH] Permit generic chat message with role (#2075)

* Permit generic chat message with role

* Address code review notes, only warn if invalid roles found

* Fix streaming

* Fix for PaLM

* Code review

* Rename to `isInstance`
---
 langchain/src/chat_models/anthropic.ts        | 26 ++++++++--
 langchain/src/chat_models/baiduwenxin.ts      | 22 ++++++--
 langchain/src/chat_models/googlepalm.ts       | 27 ++++++++--
 langchain/src/chat_models/googlevertexai.ts   | 52 ++++++++++++++-----
 langchain/src/chat_models/openai.ts           | 34 ++++++++----
 .../tests/chatanthropic.int.test.ts           | 13 ++++-
 .../tests/chatgooglevertexai.int.test.ts      |  8 ++-
 .../chat_models/tests/chatopenai.int.test.ts  | 12 +++++
 langchain/src/schema/index.ts                 |  4 ++
 9 files changed, 162 insertions(+), 36 deletions(-)
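[Usage sketch — illustrative only, not part of the diff; `git am` ignores
text between the diffstat and the first `diff --git` line. The import
paths and model name below are assumptions, not taken from this patch.]

import { ChatOpenAI } from "langchain/chat_models/openai";
import { ChatMessage } from "langchain/schema";

// ChatMessage carries an arbitrary role string. messageToOpenAIRole()
// now maps the "generic" message type through
// extractGenericMessageCustomRole() instead of throwing.
const chat = new ChatOpenAI({ modelName: "gpt-3.5-turbo" });
const res = await chat.call([
  new ChatMessage("You are to chat with a user.", "system"),
  new ChatMessage("Hello!", "user"),
]); // assumes an ESM context, where top-level await is valid

// A role outside the provider's known set (e.g. a hypothetical
// "moderator") is still forwarded, but logs
// console.warn("Unknown message role: ...") first.

For Anthropic the role must be one of the SDK's prompt constants
(AI_PROMPT / HUMAN_PROMPT) or "", since it is spliced directly into the
prompt string — see the new chatanthropic.int.test.ts case.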
diff --git a/langchain/src/chat_models/anthropic.ts b/langchain/src/chat_models/anthropic.ts
index b9ed87d40b33..d14e692758a8 100644
--- a/langchain/src/chat_models/anthropic.ts
+++ b/langchain/src/chat_models/anthropic.ts
@@ -8,13 +8,26 @@ import {
   BaseMessage,
   ChatGeneration,
   ChatGenerationChunk,
+  ChatMessage,
   ChatResult,
-  MessageType,
 } from "../schema/index.js";
 import { getEnvironmentVariable } from "../util/env.js";
 import { BaseChatModel, BaseChatModelParams } from "./base.js";
 
-function getAnthropicPromptFromMessage(type: MessageType): string {
+function extractGenericMessageCustomRole(message: ChatMessage) {
+  if (
+    message.role !== AI_PROMPT &&
+    message.role !== HUMAN_PROMPT &&
+    message.role !== ""
+  ) {
+    console.warn(`Unknown message role: ${message.role}`);
+  }
+
+  return message.role;
+}
+
+function getAnthropicPromptFromMessage(message: BaseMessage): string {
+  const type = message._getType();
   switch (type) {
     case "ai":
       return AI_PROMPT;
@@ -22,6 +35,11 @@ function getAnthropicPromptFromMessage(type: MessageType): string {
       return HUMAN_PROMPT;
     case "system":
       return "";
+    case "generic": {
+      if (!ChatMessage.isInstance(message))
+        throw new Error("Invalid generic chat message");
+      return extractGenericMessageCustomRole(message);
+    }
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
@@ -250,9 +268,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
     return (
       messages
         .map((message) => {
-          const messagePrompt = getAnthropicPromptFromMessage(
-            message._getType()
-          );
+          const messagePrompt = getAnthropicPromptFromMessage(message);
           return `${messagePrompt} ${message.content}`;
         })
         .join("") + AI_PROMPT
diff --git a/langchain/src/chat_models/baiduwenxin.ts b/langchain/src/chat_models/baiduwenxin.ts
index 3d1233d9af9d..75b761b15298 100644
--- a/langchain/src/chat_models/baiduwenxin.ts
+++ b/langchain/src/chat_models/baiduwenxin.ts
@@ -3,8 +3,8 @@ import {
   AIMessage,
   BaseMessage,
   ChatGeneration,
+  ChatMessage,
   ChatResult,
-  MessageType,
 } from "../schema/index.js";
 import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
 import { getEnvironmentVariable } from "../util/env.js";
@@ -90,7 +90,16 @@ declare interface BaiduWenxinChatInput {
   penaltyScore?: number;
 }
 
-function messageTypeToWenxinRole(type: MessageType): WenxinMessageRole {
+function extractGenericMessageCustomRole(message: ChatMessage) {
+  if (message.role !== "assistant" && message.role !== "user") {
+    console.warn(`Unknown message role: ${message.role}`);
+  }
+
+  return message.role as WenxinMessageRole;
+}
+
+function messageToWenxinRole(message: BaseMessage): WenxinMessageRole {
+  const type = message._getType();
   switch (type) {
     case "ai":
       return "assistant";
@@ -98,6 +107,13 @@ function messageTypeToWenxinRole(type: MessageType): WenxinMessageRole {
       return "user";
     case "system":
       throw new Error("System messages not supported");
+    case "function":
+      throw new Error("Function messages not supported");
+    case "generic": {
+      if (!ChatMessage.isInstance(message))
+        throw new Error("Invalid generic chat message");
+      return extractGenericMessageCustomRole(message);
+    }
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
@@ -263,7 +279,7 @@ export class ChatBaiduWenxin
     const params = this.invocationParams();
 
     const messagesMapped: WenxinMessage[] = messages.map((message) => ({
-      role: messageTypeToWenxinRole(message._getType()),
+      role: messageToWenxinRole(message),
       content: message.text,
     }));
 
diff --git a/langchain/src/chat_models/googlepalm.ts b/langchain/src/chat_models/googlepalm.ts
index 6e55985eb28b..85b7cf9a12d4 100644
--- a/langchain/src/chat_models/googlepalm.ts
+++ b/langchain/src/chat_models/googlepalm.ts
@@ -2,7 +2,12 @@ import { DiscussServiceClient } from "@google-ai/generativelanguage";
 import type { protos } from "@google-ai/generativelanguage";
 import { GoogleAuth } from "google-auth-library";
 import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
-import { AIMessage, BaseMessage, ChatResult } from "../schema/index.js";
+import {
+  AIMessage,
+  BaseMessage,
+  ChatMessage,
+  ChatResult,
+} from "../schema/index.js";
 import { getEnvironmentVariable } from "../util/env.js";
 import { BaseChatModel, BaseChatModelParams } from "./base.js";
 
@@ -60,6 +65,14 @@ export interface GooglePaLMChatInput extends BaseChatModelParams {
   apiKey?: string;
 }
 
+function getMessageAuthor(message: BaseMessage) {
+  const type = message._getType();
+  if (ChatMessage.isInstance(message)) {
+    return message.role;
+  }
+  return message.name ?? type;
+}
+
 export class ChatGooglePaLM
   extends BaseChatModel
   implements GooglePaLMChatInput
@@ -175,7 +188,7 @@ export class ChatGooglePaLM
   ): string | undefined {
     // get the first message and checks if it's a system 'system' messages
     const systemMessage =
-      messages.length > 0 && messages[0]._getType() === "system"
+      messages.length > 0 && getMessageAuthor(messages[0]) === "system"
         ? messages[0]
         : undefined;
     return systemMessage?.content;
@@ -185,12 +198,16 @@ export class ChatGooglePaLM
     messages: BaseMessage[]
   ): protos.google.ai.generativelanguage.v1beta2.IMessage[] {
     // remove all 'system' messages
-    const nonSystemMessages = messages.filter((m) => m._getType() !== "system");
+    const nonSystemMessages = messages.filter(
+      (m) => getMessageAuthor(m) !== "system"
+    );
 
     // requires alternate human & ai messages. Throw error if two messages are consecutive
     nonSystemMessages.forEach((msg, index) => {
       if (index < 1) return;
-      if (msg._getType() === nonSystemMessages[index - 1]._getType()) {
+      if (
+        getMessageAuthor(msg) === getMessageAuthor(nonSystemMessages[index - 1])
+      ) {
         throw new Error(
           `Google PaLM requires alternate messages between authors`
         );
@@ -198,7 +215,7 @@ export class ChatGooglePaLM
     });
 
     return nonSystemMessages.map((m) => ({
-      author: m.name ?? m._getType(),
+      author: getMessageAuthor(m),
       content: m.content,
       citationMetadata: {
         citationSources: m.additional_kwargs.citationSources as
diff --git a/langchain/src/chat_models/googlevertexai.ts b/langchain/src/chat_models/googlevertexai.ts
index e67bd9b2f530..ce323d0cf823 100644
--- a/langchain/src/chat_models/googlevertexai.ts
+++ b/langchain/src/chat_models/googlevertexai.ts
@@ -3,9 +3,9 @@ import {
   AIMessage,
   BaseMessage,
   ChatGeneration,
+  ChatMessage,
   ChatResult,
   LLMResult,
-  MessageType,
 } from "../schema/index.js";
 import { GoogleVertexAIConnection } from "../util/googlevertexai-connection.js";
 import {
@@ -54,11 +54,25 @@ export class GoogleVertexAIChatMessage {
     this.name = fields.name;
   }
 
+  static extractGenericMessageCustomRole(message: ChatMessage) {
+    if (
+      message.role !== "system" &&
+      message.role !== "bot" &&
+      message.role !== "user" &&
+      message.role !== "context"
+    ) {
+      console.warn(`Unknown message role: ${message.role}`);
+    }
+
+    return message.role as GoogleVertexAIChatAuthor;
+  }
+
   static mapMessageTypeToVertexChatAuthor(
-    baseMessageType: MessageType,
+    message: BaseMessage,
     model: string
   ): GoogleVertexAIChatAuthor {
-    switch (baseMessageType) {
+    const type = message._getType();
+    switch (type) {
       case "ai":
         return model.startsWith("codechat-") ? "system" : "bot";
       case "human":
@@ -67,17 +81,22 @@ export class GoogleVertexAIChatMessage {
         throw new Error(
           `System messages are only supported as the first passed message for Google Vertex AI.`
         );
-      default:
-        throw new Error(
-          `Unknown / unsupported message type: ${baseMessageType}`
+      case "generic": {
+        if (!ChatMessage.isInstance(message))
+          throw new Error("Invalid generic chat message");
+        return GoogleVertexAIChatMessage.extractGenericMessageCustomRole(
+          message
         );
+      }
+      default:
+        throw new Error(`Unknown / unsupported message type: ${message}`);
     }
   }
 
   static fromChatMessage(message: BaseMessage, model: string) {
     return new GoogleVertexAIChatMessage({
       author: GoogleVertexAIChatMessage.mapMessageTypeToVertexChatAuthor(
-        message._getType(),
+        message,
         model
       ),
       content: message.content,
@@ -211,16 +230,25 @@ export class ChatGoogleVertexAI
       );
     }
     const vertexChatMessages = conversationMessages.map((baseMessage, i) => {
+      const currMessage = GoogleVertexAIChatMessage.fromChatMessage(
+        baseMessage,
+        this.model
+      );
+      const prevMessage =
+        i > 0
+          ? GoogleVertexAIChatMessage.fromChatMessage(
+              conversationMessages[i - 1],
+              this.model
+            )
+          : null;
+
       // https://cloud.google.com/vertex-ai/docs/generative-ai/chat/chat-prompts#messages
-      if (
-        i > 0 &&
-        baseMessage._getType() === conversationMessages[i - 1]._getType()
-      ) {
+      if (prevMessage && currMessage.author === prevMessage.author) {
         throw new Error(
           `Google Vertex AI requires AI and human messages to alternate.`
         );
       }
-      return GoogleVertexAIChatMessage.fromChatMessage(baseMessage, this.model);
+      return currMessage;
     });
 
     const examples = this.examples.map((example) => ({
diff --git a/langchain/src/chat_models/openai.ts b/langchain/src/chat_models/openai.ts
index bbabaeae7fdc..86446490f37b 100644
--- a/langchain/src/chat_models/openai.ts
+++ b/langchain/src/chat_models/openai.ts
@@ -32,7 +32,6 @@ import {
   FunctionMessageChunk,
   HumanMessage,
   HumanMessageChunk,
-  MessageType,
   SystemMessage,
   SystemMessageChunk,
 } from "../schema/index.js";
@@ -56,9 +55,23 @@ interface OpenAILLMOutput {
   tokenUsage: TokenUsage;
 }
 
-function messageTypeToOpenAIRole(
-  type: MessageType
+function extractGenericMessageCustomRole(message: ChatMessage) {
+  if (
+    message.role !== "system" &&
+    message.role !== "assistant" &&
+    message.role !== "user" &&
+    message.role !== "function"
+  ) {
+    console.warn(`Unknown message role: ${message.role}`);
+  }
+
+  return message.role as ChatCompletionResponseMessageRoleEnum;
+}
+
+function messageToOpenAIRole(
+  message: BaseMessage
 ): ChatCompletionResponseMessageRoleEnum {
+  const type = message._getType();
   switch (type) {
     case "system":
       return "system";
@@ -68,6 +81,11 @@ function messageTypeToOpenAIRole(
       return "user";
     case "function":
       return "function";
+    case "generic": {
+      if (!ChatMessage.isInstance(message))
+        throw new Error("Invalid generic chat message");
+      return extractGenericMessageCustomRole(message);
+    }
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
@@ -340,7 +358,7 @@ export class ChatOpenAI
   ): AsyncGenerator<ChatGenerationChunk> {
     const messagesMapped: ChatCompletionRequestMessage[] = messages.map(
       (message) => ({
-        role: messageTypeToOpenAIRole(message._getType()),
+        role: messageToOpenAIRole(message),
         content: message.content,
         name: message.name,
         function_call: message.additional_kwargs
@@ -455,7 +473,7 @@ export class ChatOpenAI
     const params = this.invocationParams(options);
     const messagesMapped: ChatCompletionRequestMessage[] = messages.map(
       (message) => ({
-        role: messageTypeToOpenAIRole(message._getType()),
+        role: messageToOpenAIRole(message),
         content: message.content,
         name: message.name,
         function_call: message.additional_kwargs
@@ -661,9 +679,7 @@ export class ChatOpenAI
     const countPerMessage = await Promise.all(
       messages.map(async (message) => {
         const textCount = await this.getNumTokens(message.content);
-        const roleCount = await this.getNumTokens(
-          messageTypeToOpenAIRole(message._getType())
-        );
+        const roleCount = await this.getNumTokens(messageToOpenAIRole(message));
         const nameCount =
           message.name !== undefined
             ? tokensPerName + (await this.getNumTokens(message.name))
@@ -865,7 +881,7 @@ export class PromptLayerChatOpenAI extends ChatOpenAI {
       const parsedResp = [
         {
           content: generation.text,
-          role: messageTypeToOpenAIRole(generation.message._getType()),
+          role: messageToOpenAIRole(generation.message),
         },
       ];
 
diff --git a/langchain/src/chat_models/tests/chatanthropic.int.test.ts b/langchain/src/chat_models/tests/chatanthropic.int.test.ts
index f76bd975aa67..bf0f1b6923be 100644
--- a/langchain/src/chat_models/tests/chatanthropic.int.test.ts
+++ b/langchain/src/chat_models/tests/chatanthropic.int.test.ts
@@ -1,5 +1,6 @@
 import { expect, test } from "@jest/globals";
-import { HumanMessage } from "../../schema/index.js";
+import { HUMAN_PROMPT } from "@anthropic-ai/sdk";
+import { ChatMessage, HumanMessage } from "../../schema/index.js";
 import { ChatPromptValue } from "../../prompts/chat.js";
 import {
   PromptTemplate,
@@ -213,6 +214,16 @@ test("ChatAnthropic, Claude V2", async () => {
   console.log(responseA.generations);
 });
 
+test("ChatAnthropic with specific roles in ChatMessage", async () => {
+  const chat = new ChatAnthropic({
+    modelName: "claude-instant-v1",
+    maxTokensToSample: 10,
+  });
+  const user_message = new ChatMessage("Hello!", HUMAN_PROMPT);
+  const res = await chat.call([user_message]);
+  console.log({ res });
+});
+
 test("Test ChatAnthropic stream method", async () => {
   const model = new ChatAnthropic({
     maxTokensToSample: 50,
diff --git a/langchain/src/chat_models/tests/chatgooglevertexai.int.test.ts b/langchain/src/chat_models/tests/chatgooglevertexai.int.test.ts
index bb92663acd26..ef17a75b5d27 100644
--- a/langchain/src/chat_models/tests/chatgooglevertexai.int.test.ts
+++ b/langchain/src/chat_models/tests/chatgooglevertexai.int.test.ts
@@ -1,5 +1,5 @@
 import { test } from "@jest/globals";
-import { HumanMessage } from "../../schema/index.js";
+import { ChatMessage, HumanMessage } from "../../schema/index.js";
 import {
   PromptTemplate,
   ChatPromptTemplate,
@@ -26,6 +26,12 @@ test("Test ChatGoogleVertexAI generate", async () => {
   console.log(JSON.stringify(res, null, 2));
 });
 
+test("Google code messages with custom messages", async () => {
+  const chat = new ChatGoogleVertexAI();
+  const res = await chat.call([new ChatMessage("Hello!", "user")]);
+  console.log(JSON.stringify(res, null, 2));
+});
+
 test("ChatGoogleVertexAI, prompt templates", async () => {
   const chat = new ChatGoogleVertexAI();
 
diff --git a/langchain/src/chat_models/tests/chatopenai.int.test.ts b/langchain/src/chat_models/tests/chatopenai.int.test.ts
index aaf8f42f1458..6e89ef0ac9f3 100644
--- a/langchain/src/chat_models/tests/chatopenai.int.test.ts
+++ b/langchain/src/chat_models/tests/chatopenai.int.test.ts
@@ -2,6 +2,7 @@ import { test, expect } from "@jest/globals";
 import { ChatOpenAI } from "../openai.js";
 import {
   BaseMessage,
+  ChatMessage,
   ChatGeneration,
   HumanMessage,
   LLMResult,
@@ -334,6 +335,17 @@ test("getNumTokensFromMessages gpt-4-0314 model for sample input", async () => {
   expect(totalCount).toBe(129);
 });
 
+test("Test OpenAI with specific roles in ChatMessage", async () => {
+  const chat = new ChatOpenAI({ modelName: "gpt-3.5-turbo", maxTokens: 10 });
+  const system_message = new ChatMessage(
+    "You are to chat with a user.",
+    "system"
+  );
+  const user_message = new ChatMessage("Hello!", "user");
+  const res = await chat.call([system_message, user_message]);
+  console.log({ res });
+});
+
 test("Test ChatOpenAI stream method", async () => {
   const model = new ChatOpenAI({ maxTokens: 50, modelName: "gpt-3.5-turbo" });
   const stream = await model.stream("Print hello world.");
diff --git a/langchain/src/schema/index.ts b/langchain/src/schema/index.ts
index db5a63230862..29da2a185ff3 100644
--- a/langchain/src/schema/index.ts
+++ b/langchain/src/schema/index.ts
@@ -348,6 +348,10 @@ export class ChatMessage
   _getType(): MessageType {
     return "generic";
   }
+
+  static isInstance(message: BaseMessage): message is ChatMessage {
+    return message._getType() === "generic";
+  }
 }
 
 export class ChatMessageChunk extends BaseMessageChunk {