Skip to content

Commit

Permalink
Add support for codechat- models. Which, of course, take different pa…
Browse files Browse the repository at this point in the history
…rameters and have different names for some of the roles. (#1728)
  • Loading branch information
afirstenberg authored Jun 26, 2023
1 parent 722d8e5 commit 516e3cd
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 17 deletions.
34 changes: 22 additions & 12 deletions langchain/src/chat_models/googlevertexai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ interface GoogleVertexAIChatExample {
output: GoogleVertexAIChatMessage;
}

export type GoogleVertexAIChatAuthor = "user" | "bot" | "context";
export type GoogleVertexAIChatAuthor =
| "user" // Represents the human for Code and CodeChat models
| "bot" // Represents the AI for Code models
| "system" // Represents the AI for CodeChat models
| "context"; // Represents contextual instructions

export type GoogleVertexAIChatMessageFields = {
author?: GoogleVertexAIChatAuthor;
Expand All @@ -50,30 +54,30 @@ export class GoogleVertexAIChatMessage {
}

static mapMessageTypeToVertexChatAuthor(
baseMessageType: MessageType
baseMessageType: MessageType,
model: string
): GoogleVertexAIChatAuthor {
switch (baseMessageType) {
case "ai":
return "bot";
return model.startsWith("codechat-") ? "system" : "bot";
case "human":
return "user";
case "system":
throw new Error(
`System messages are only supported as the first passed message for Google Vertex AI.`
);
case "generic":
default:
throw new Error(
`Generic messages are not supported by Google Vertex AI.`
`Unknown / unsupported message type: ${baseMessageType}`
);
default:
throw new Error(`Unknown message type: ${baseMessageType}`);
}
}

static fromChatMessage(message: BaseChatMessage) {
static fromChatMessage(message: BaseChatMessage, model: string) {
return new GoogleVertexAIChatMessage({
author: GoogleVertexAIChatMessage.mapMessageTypeToVertexChatAuthor(
message._getType()
message._getType(),
model
),
content: message.text,
});
Expand Down Expand Up @@ -215,12 +219,18 @@ export class ChatGoogleVertexAI
`Google Vertex AI requires AI and human messages to alternate.`
);
}
return GoogleVertexAIChatMessage.fromChatMessage(baseMessage);
return GoogleVertexAIChatMessage.fromChatMessage(baseMessage, this.model);
});

const examples = this.examples.map((example) => ({
input: GoogleVertexAIChatMessage.fromChatMessage(example.input),
output: GoogleVertexAIChatMessage.fromChatMessage(example.output),
input: GoogleVertexAIChatMessage.fromChatMessage(
example.input,
this.model
),
output: GoogleVertexAIChatMessage.fromChatMessage(
example.output,
this.model
),
}));

const instance: GoogleVertexAIChatInstance = {
Expand Down
29 changes: 24 additions & 5 deletions langchain/src/chat_models/tests/chatgooglevertexai.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ import { ConversationChain } from "../../chains/conversation.js";
import { BufferMemory } from "../../memory/buffer_memory.js";
import { ChatGoogleVertexAI } from "../googlevertexai.js";

test.skip("Test ChatGoogleVertexAI", async () => {
test("Test ChatGoogleVertexAI", async () => {
const chat = new ChatGoogleVertexAI();
const message = new HumanChatMessage("Hello!");
const res = await chat.call([message]);
console.log({ res });
});

test.skip("Test ChatGoogleVertexAI generate", async () => {
test("Test ChatGoogleVertexAI generate", async () => {
const chat = new ChatGoogleVertexAI();
const message = new HumanChatMessage("Hello!");
const res = await chat.generate([[message]]);
console.log(JSON.stringify(res, null, 2));
});

test.skip("ChatGoogleVertexAI, prompt templates", async () => {
test("ChatGoogleVertexAI, prompt templates", async () => {
const chat = new ChatGoogleVertexAI();

// PaLM doesn't support translation yet
Expand All @@ -49,7 +49,7 @@ test.skip("ChatGoogleVertexAI, prompt templates", async () => {
console.log(responseA.generations);
});

test.skip("ChatGoogleVertexAI, longer chain of messages", async () => {
test("ChatGoogleVertexAI, longer chain of messages", async () => {
const chat = new ChatGoogleVertexAI();

const chatPrompt = ChatPromptTemplate.fromPromptMessages([
Expand All @@ -67,7 +67,7 @@ test.skip("ChatGoogleVertexAI, longer chain of messages", async () => {
console.log(responseA.generations);
});

test.skip("ChatGoogleVertexAI, with a memory in a chain", async () => {
test("ChatGoogleVertexAI, with a memory in a chain", async () => {
const chatPrompt = ChatPromptTemplate.fromPromptMessages([
SystemMessagePromptTemplate.fromTemplate(
"You are a helpful assistant who must always respond like a pirate"
Expand All @@ -94,3 +94,22 @@ test.skip("ChatGoogleVertexAI, with a memory in a chain", async () => {

console.log(response2);
});

test("CodechatGoogleVertexAI, chain of messages", async () => {
const chat = new ChatGoogleVertexAI({ model: "codechat-bison" });

const chatPrompt = ChatPromptTemplate.fromPromptMessages([
SystemMessagePromptTemplate.fromTemplate(
`Answer all questions using Python and just show the code without an explanation.`
),
HumanMessagePromptTemplate.fromTemplate("{text}"),
]);

const responseA = await chat.generatePrompt([
await chatPrompt.formatPromptValue({
text: "How can I write a for loop counting to 10?",
}),
]);

console.log(JSON.stringify(responseA.generations, null, 1));
});
27 changes: 27 additions & 0 deletions langchain/src/chat_models/tests/chatgooglevertexai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,30 @@ test("Google Throw an error for an even number of non-system input messages", as
const model = new ChatGoogleVertexAI();
expect(() => model.createInstance(messages)).toThrow();
});

test("Google code messages", async () => {
const messages: BaseChatMessage[] = [
new HumanChatMessage("Human1"),
new AIChatMessage("AI1"),
new HumanChatMessage("Human2"),
];
const model = new ChatGoogleVertexAI({ model: "codechat-bison" });
const instance = model.createInstance(messages);
expect(instance.context).toBe("");
expect(instance.messages[0].author).toBe("user");
expect(instance.messages[1].author).toBe("system");
});

test("Google code messages with a system message", async () => {
const messages: BaseChatMessage[] = [
new SystemChatMessage("System1"),
new HumanChatMessage("Human1"),
new AIChatMessage("AI1"),
new HumanChatMessage("Human2"),
];
const model = new ChatGoogleVertexAI({ model: "codechat-bison" });
const instance = model.createInstance(messages);
expect(instance.context).toBe("System1");
expect(instance.messages[0].author).toBe("user");
expect(instance.messages[1].author).toBe("system");
});

1 comment on commit 516e3cd

@vercel
Copy link

@vercel vercel bot commented on 516e3cd Jun 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.