Permit generic chat message with role (#2075)
* Permit generic chat message with role

* Address code review notes, only warn if invalid roles found

* Fix streaming

* Fix for PaLM

* Code review

* Rename to `isInstance`
dqbd authored Jul 31, 2023
1 parent db9f000 commit cd4a807
Showing 9 changed files with 162 additions and 36 deletions.
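
Every adapter in this diff follows the same shape: the role-mapping helper now receives the whole BaseMessage rather than just its MessageType, and a new "generic" branch narrows to ChatMessage via the new ChatMessage.isInstance check and forwards the custom role, warning (rather than throwing) on roles the provider does not recognize. A minimal sketch of that shared pattern — the ProviderRole alias and function names here are illustrative, not taken from any one file:

import { BaseMessage, ChatMessage } from "langchain/schema";

// Hypothetical provider role union, for illustration only.
type ProviderRole = "assistant" | "user";

function extractGenericMessageCustomRole(message: ChatMessage): ProviderRole {
  if (message.role !== "assistant" && message.role !== "user") {
    // Unrecognized custom roles are passed through with a warning, not rejected.
    console.warn(`Unknown message role: ${message.role}`);
  }
  return message.role as ProviderRole;
}

function messageToProviderRole(message: BaseMessage): ProviderRole {
  const type = message._getType();
  switch (type) {
    case "ai":
      return "assistant";
    case "human":
      return "user";
    case "generic": {
      // A "generic" message carries a free-form role; narrow to ChatMessage first.
      if (!ChatMessage.isInstance(message))
        throw new Error("Invalid generic chat message");
      return extractGenericMessageCustomRole(message);
    }
    default:
      throw new Error(`Unknown message type: ${type}`);
  }
}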
26 changes: 21 additions & 5 deletions langchain/src/chat_models/anthropic.ts
@@ -8,20 +8,38 @@ import {
   BaseMessage,
   ChatGeneration,
   ChatGenerationChunk,
+  ChatMessage,
   ChatResult,
-  MessageType,
 } from "../schema/index.js";
 import { getEnvironmentVariable } from "../util/env.js";
 import { BaseChatModel, BaseChatModelParams } from "./base.js";
 
-function getAnthropicPromptFromMessage(type: MessageType): string {
+function extractGenericMessageCustomRole(message: ChatMessage) {
+  if (
+    message.role !== AI_PROMPT &&
+    message.role !== HUMAN_PROMPT &&
+    message.role !== ""
+  ) {
+    console.warn(`Unknown message role: ${message.role}`);
+  }
+
+  return message.role;
+}
+
+function getAnthropicPromptFromMessage(message: BaseMessage): string {
+  const type = message._getType();
   switch (type) {
     case "ai":
       return AI_PROMPT;
     case "human":
       return HUMAN_PROMPT;
     case "system":
       return "";
+    case "generic": {
+      if (!ChatMessage.isInstance(message))
+        throw new Error("Invalid generic chat message");
+      return extractGenericMessageCustomRole(message);
+    }
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
@@ -250,9 +268,7 @@ export class ChatAnthropic extends BaseChatModel implements AnthropicInput {
     return (
       messages
         .map((message) => {
-          const messagePrompt = getAnthropicPromptFromMessage(
-            message._getType()
-          );
+          const messagePrompt = getAnthropicPromptFromMessage(message);
           return `${messagePrompt} ${message.content}`;
         })
         .join("") + AI_PROMPT
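
Callers can now address Anthropic with a generic ChatMessage whose role is one of the SDK prompt constants, as the integration test added later in this commit does. A usage sketch under that assumption (entrypoint paths reflect the package layout at this commit):

import { HUMAN_PROMPT } from "@anthropic-ai/sdk";
import { ChatMessage } from "langchain/schema";
import { ChatAnthropic } from "langchain/chat_models/anthropic";

async function main() {
  const chat = new ChatAnthropic({ maxTokensToSample: 10 });
  // The role string is used verbatim as the prompt prefix ("\n\nHuman:").
  const res = await chat.call([new ChatMessage("Hello!", HUMAN_PROMPT)]);
  console.log(res.content);
}

main().catch(console.error);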
22 changes: 19 additions & 3 deletions langchain/src/chat_models/baiduwenxin.ts
@@ -3,8 +3,8 @@ import {
   AIMessage,
   BaseMessage,
   ChatGeneration,
+  ChatMessage,
   ChatResult,
-  MessageType,
 } from "../schema/index.js";
 import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
 import { getEnvironmentVariable } from "../util/env.js";
@@ -90,14 +90,30 @@ declare interface BaiduWenxinChatInput {
   penaltyScore?: number;
 }
 
-function messageTypeToWenxinRole(type: MessageType): WenxinMessageRole {
+function extractGenericMessageCustomRole(message: ChatMessage) {
+  if (message.role !== "assistant" && message.role !== "user") {
+    console.warn(`Unknown message role: ${message.role}`);
+  }
+
+  return message.role as WenxinMessageRole;
+}
+
+function messageToWenxinRole(message: BaseMessage): WenxinMessageRole {
+  const type = message._getType();
   switch (type) {
     case "ai":
       return "assistant";
     case "human":
       return "user";
     case "system":
       throw new Error("System messages not supported");
+    case "function":
+      throw new Error("Function messages not supported");
+    case "generic": {
+      if (!ChatMessage.isInstance(message))
+        throw new Error("Invalid generic chat message");
+      return extractGenericMessageCustomRole(message);
+    }
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
@@ -263,7 +279,7 @@ export class ChatBaiduWenxin
 
     const params = this.invocationParams();
     const messagesMapped: WenxinMessage[] = messages.map((message) => ({
-      role: messageTypeToWenxinRole(message._getType()),
+      role: messageToWenxinRole(message),
       content: message.text,
     }));
 
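
For Wenxin, only "assistant" and "user" are recognized custom roles; anything else is logged and cast through as-is. A hedged usage sketch (the entrypoint path is an assumption based on the file layout):

import { ChatMessage } from "langchain/schema";
import { ChatBaiduWenxin } from "langchain/chat_models/baiduwenxin";

async function main() {
  const chat = new ChatBaiduWenxin();
  // A generic "user" message maps to the same Wenxin role as a HumanMessage.
  const res = await chat.call([new ChatMessage("Hello!", "user")]);
  console.log(res.content);
}

main().catch(console.error);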
27 changes: 22 additions & 5 deletions langchain/src/chat_models/googlepalm.ts
@@ -2,7 +2,12 @@ import { DiscussServiceClient } from "@google-ai/generativelanguage";
 import type { protos } from "@google-ai/generativelanguage";
 import { GoogleAuth } from "google-auth-library";
 import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
-import { AIMessage, BaseMessage, ChatResult } from "../schema/index.js";
+import {
+  AIMessage,
+  BaseMessage,
+  ChatMessage,
+  ChatResult,
+} from "../schema/index.js";
 import { getEnvironmentVariable } from "../util/env.js";
 import { BaseChatModel, BaseChatModelParams } from "./base.js";
@@ -60,6 +65,14 @@ export interface GooglePaLMChatInput extends BaseChatModelParams {
   apiKey?: string;
 }
 
+function getMessageAuthor(message: BaseMessage) {
+  const type = message._getType();
+  if (ChatMessage.isInstance(message)) {
+    return message.role;
+  }
+  return message.name ?? type;
+}
+
 export class ChatGooglePaLM
   extends BaseChatModel
   implements GooglePaLMChatInput
@@ -175,7 +188,7 @@ export class ChatGooglePaLM
   ): string | undefined {
     // get the first message and checks if it's a system 'system' messages
     const systemMessage =
-      messages.length > 0 && messages[0]._getType() === "system"
+      messages.length > 0 && getMessageAuthor(messages[0]) === "system"
         ? messages[0]
         : undefined;
     return systemMessage?.content;
@@ -185,20 +198,24 @@
     messages: BaseMessage[]
   ): protos.google.ai.generativelanguage.v1beta2.IMessage[] {
     // remove all 'system' messages
-    const nonSystemMessages = messages.filter((m) => m._getType() !== "system");
+    const nonSystemMessages = messages.filter(
+      (m) => getMessageAuthor(m) !== "system"
+    );
 
     // requires alternate human & ai messages. Throw error if two messages are consecutive
     nonSystemMessages.forEach((msg, index) => {
       if (index < 1) return;
-      if (msg._getType() === nonSystemMessages[index - 1]._getType()) {
+      if (
+        getMessageAuthor(msg) === getMessageAuthor(nonSystemMessages[index - 1])
+      ) {
         throw new Error(
           `Google PaLM requires alternate messages between authors`
         );
       }
     });
 
     return nonSystemMessages.map((m) => ({
-      author: m.name ?? m._getType(),
+      author: getMessageAuthor(m),
       content: m.content,
       citationMetadata: {
         citationSources: m.additional_kwargs.citationSources as
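
getMessageAuthor resolves the PaLM author field in priority order: a ChatMessage's custom role, then message.name, then the message type. The helper is module-private in the diff, so it is re-stated below purely to illustrate that precedence:

import {
  AIMessage,
  BaseMessage,
  ChatMessage,
  HumanMessage,
} from "langchain/schema";

// Re-stated here for illustration; in the diff this helper is module-private.
function getMessageAuthor(message: BaseMessage) {
  const type = message._getType();
  if (ChatMessage.isInstance(message)) {
    return message.role;
  }
  return message.name ?? type;
}

console.log(getMessageAuthor(new ChatMessage("hi", "user"))); // "user": custom role wins
console.log(getMessageAuthor(new AIMessage("hi"))); // "ai": falls back to the message type
console.log(getMessageAuthor(new HumanMessage({ content: "hi", name: "alice" }))); // "alice": name beats type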
52 changes: 40 additions & 12 deletions langchain/src/chat_models/googlevertexai.ts
@@ -3,9 +3,9 @@ import {
   AIMessage,
   BaseMessage,
   ChatGeneration,
+  ChatMessage,
   ChatResult,
   LLMResult,
-  MessageType,
 } from "../schema/index.js";
 import { GoogleVertexAIConnection } from "../util/googlevertexai-connection.js";
 import {
@@ -54,11 +54,25 @@ export class GoogleVertexAIChatMessage {
     this.name = fields.name;
   }
 
+  static extractGenericMessageCustomRole(message: ChatMessage) {
+    if (
+      message.role !== "system" &&
+      message.role !== "bot" &&
+      message.role !== "user" &&
+      message.role !== "context"
+    ) {
+      console.warn(`Unknown message role: ${message.role}`);
+    }
+
+    return message.role as GoogleVertexAIChatAuthor;
+  }
+
   static mapMessageTypeToVertexChatAuthor(
-    baseMessageType: MessageType,
+    message: BaseMessage,
     model: string
   ): GoogleVertexAIChatAuthor {
-    switch (baseMessageType) {
+    const type = message._getType();
+    switch (type) {
       case "ai":
         return model.startsWith("codechat-") ? "system" : "bot";
       case "human":
@@ -67,17 +81,22 @@
         throw new Error(
           `System messages are only supported as the first passed message for Google Vertex AI.`
         );
-      default:
-        throw new Error(
-          `Unknown / unsupported message type: ${baseMessageType}`
+      case "generic": {
+        if (!ChatMessage.isInstance(message))
+          throw new Error("Invalid generic chat message");
+        return GoogleVertexAIChatMessage.extractGenericMessageCustomRole(
+          message
         );
+      }
+      default:
+        throw new Error(`Unknown / unsupported message type: ${message}`);
     }
   }
 
   static fromChatMessage(message: BaseMessage, model: string) {
     return new GoogleVertexAIChatMessage({
       author: GoogleVertexAIChatMessage.mapMessageTypeToVertexChatAuthor(
-        message._getType(),
+        message,
         model
       ),
       content: message.content,
@@ -211,16 +230,25 @@ export class ChatGoogleVertexAI
       );
     }
     const vertexChatMessages = conversationMessages.map((baseMessage, i) => {
+      const currMessage = GoogleVertexAIChatMessage.fromChatMessage(
+        baseMessage,
+        this.model
+      );
+      const prevMessage =
+        i > 0
+          ? GoogleVertexAIChatMessage.fromChatMessage(
+              conversationMessages[i - 1],
+              this.model
+            )
+          : null;
+
       // https://cloud.google.com/vertex-ai/docs/generative-ai/chat/chat-prompts#messages
-      if (
-        i > 0 &&
-        baseMessage._getType() === conversationMessages[i - 1]._getType()
-      ) {
+      if (prevMessage && currMessage.author === prevMessage.author) {
        throw new Error(
          `Google Vertex AI requires AI and human messages to alternate.`
        );
      }
-      return GoogleVertexAIChatMessage.fromChatMessage(baseMessage, this.model);
+      return currMessage;
     });
 
     const examples = this.examples.map((example) => ({
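
Because the alternation guard now compares resolved authors rather than raw message types, a generic ChatMessage with role "user" and a plain HumanMessage count as the same author. A sketch of what the new check accepts and rejects (assumes GCP credentials are configured):

import { ChatMessage, HumanMessage } from "langchain/schema";
import { ChatGoogleVertexAI } from "langchain/chat_models/googlevertexai";

async function main() {
  const chat = new ChatGoogleVertexAI();

  // Fine: a single generic message whose custom role is "user".
  await chat.call([new ChatMessage("Hello!", "user")]);

  // Throws: both messages resolve to the "user" author, so they don't alternate.
  await chat.call([new ChatMessage("Hi", "user"), new HumanMessage("again")]);
}

main().catch(console.error);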
34 changes: 25 additions & 9 deletions langchain/src/chat_models/openai.ts
@@ -32,7 +32,6 @@ import {
   FunctionMessageChunk,
   HumanMessage,
   HumanMessageChunk,
-  MessageType,
   SystemMessage,
   SystemMessageChunk,
 } from "../schema/index.js";
@@ -56,9 +55,23 @@ interface OpenAILLMOutput {
   tokenUsage: TokenUsage;
 }
 
-function messageTypeToOpenAIRole(
-  type: MessageType
+function extractGenericMessageCustomRole(message: ChatMessage) {
+  if (
+    message.role !== "system" &&
+    message.role !== "assistant" &&
+    message.role !== "user" &&
+    message.role !== "function"
+  ) {
+    console.warn(`Unknown message role: ${message.role}`);
+  }
+
+  return message.role as ChatCompletionResponseMessageRoleEnum;
+}
+
+function messageToOpenAIRole(
+  message: BaseMessage
 ): ChatCompletionResponseMessageRoleEnum {
+  const type = message._getType();
   switch (type) {
     case "system":
       return "system";
@@ -68,6 +81,11 @@ function messageTypeToOpenAIRole(
       return "user";
     case "function":
       return "function";
+    case "generic": {
+      if (!ChatMessage.isInstance(message))
+        throw new Error("Invalid generic chat message");
+      return extractGenericMessageCustomRole(message);
+    }
     default:
       throw new Error(`Unknown message type: ${type}`);
   }
@@ -340,7 +358,7 @@ export class ChatOpenAI
   ): AsyncGenerator<ChatGenerationChunk> {
     const messagesMapped: ChatCompletionRequestMessage[] = messages.map(
       (message) => ({
-        role: messageTypeToOpenAIRole(message._getType()),
+        role: messageToOpenAIRole(message),
         content: message.content,
         name: message.name,
         function_call: message.additional_kwargs
@@ -455,7 +473,7 @@ export class ChatOpenAI
     const params = this.invocationParams(options);
     const messagesMapped: ChatCompletionRequestMessage[] = messages.map(
       (message) => ({
-        role: messageTypeToOpenAIRole(message._getType()),
+        role: messageToOpenAIRole(message),
         content: message.content,
         name: message.name,
         function_call: message.additional_kwargs
@@ -661,9 +679,7 @@ export class ChatOpenAI
     const countPerMessage = await Promise.all(
       messages.map(async (message) => {
         const textCount = await this.getNumTokens(message.content);
-        const roleCount = await this.getNumTokens(
-          messageTypeToOpenAIRole(message._getType())
-        );
+        const roleCount = await this.getNumTokens(messageToOpenAIRole(message));
         const nameCount =
           message.name !== undefined
             ? tokensPerName + (await this.getNumTokens(message.name))
@@ -865,7 +881,7 @@ export class PromptLayerChatOpenAI extends ChatOpenAI {
     const parsedResp = [
       {
         content: generation.text,
-        role: messageTypeToOpenAIRole(generation.message._getType()),
+        role: messageToOpenAIRole(generation.message),
       },
     ];
 
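
The recognized custom roles for OpenAI mirror ChatCompletionResponseMessageRoleEnum ("system", "assistant", "user", "function"); unknown roles trigger a warning but are still sent. A usage sketch under that assumption:

import { ChatMessage } from "langchain/schema";
import { ChatOpenAI } from "langchain/chat_models/openai";

async function main() {
  const chat = new ChatOpenAI();
  // Both roles are recognized, so no warning is logged.
  const res = await chat.call([
    new ChatMessage("You are a terse assistant.", "system"),
    new ChatMessage("Hello!", "user"),
  ]);
  console.log(res.content);
}

main().catch(console.error);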
13 changes: 12 additions & 1 deletion langchain/src/chat_models/tests/chatanthropic.int.test.ts
@@ -1,5 +1,6 @@
 import { expect, test } from "@jest/globals";
-import { HumanMessage } from "../../schema/index.js";
+import { HUMAN_PROMPT } from "@anthropic-ai/sdk";
+import { ChatMessage, HumanMessage } from "../../schema/index.js";
 import { ChatPromptValue } from "../../prompts/chat.js";
 import {
   PromptTemplate,
@@ -213,6 +214,16 @@ test("ChatAnthropic, Claude V2", async () => {
   console.log(responseA.generations);
 });
 
+test("ChatAnthropic with specific roles in ChatMessage", async () => {
+  const chat = new ChatAnthropic({
+    modelName: "claude-instant-v1",
+    maxTokensToSample: 10,
+  });
+  const user_message = new ChatMessage("Hello!", HUMAN_PROMPT);
+  const res = await chat.call([user_message]);
+  console.log({ res });
+});
+
 test("Test ChatAnthropic stream method", async () => {
   const model = new ChatAnthropic({
     maxTokensToSample: 50,
@@ -1,5 +1,5 @@
 import { test } from "@jest/globals";
-import { HumanMessage } from "../../schema/index.js";
+import { ChatMessage, HumanMessage } from "../../schema/index.js";
 import {
   PromptTemplate,
   ChatPromptTemplate,
@@ -26,6 +26,12 @@ test("Test ChatGoogleVertexAI generate", async () => {
   console.log(JSON.stringify(res, null, 2));
 });
 
+test("Google code messages with custom messages", async () => {
+  const chat = new ChatGoogleVertexAI();
+  const res = await chat.call([new ChatMessage("Hello!", "user")]);
+  console.log(JSON.stringify(res, null, 2));
+});
+
 test("ChatGoogleVertexAI, prompt templates", async () => {
   const chat = new ChatGoogleVertexAI();
 

1 comment on commit cd4a807

@vercel (bot) commented on cd4a807, Jul 31, 2023