Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Feature/streamable adapter #114

Merged
merged 5 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { AdapterRequest } from "./adapter";

export interface StreamableAdapter {
generateStreamableResponse(args: AdapterRequest): Promise<string>;
generateStreamableResponse(args: AdapterRequest): AsyncGenerator<string>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't seen that in a long time 😂

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahahahahah yeah, i figured that streams should naturally be generators XD

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,36 @@ describe("GPT-4 Adapter", () => {
// Assert the message content to contain the word akeru
expect(result.toLocaleLowerCase()).toContain("akeru");
});

test("Returns streamable GPT-4 chat completions response", async () => {
// Arrange
const messages = [
{
role: "user" as Role,
content: "hello, who are you?",
},
];
const assistant_instructions =
"You're an AI assistant. You're job is to help the user. Always respond with the word akeru.";

const gpt4Adapter = AdapterManager.instance.getStreamableAdapter("gpt-4");

if (!gpt4Adapter) {
throw new Error("GPT-4 adapter not found");
}

const result = gpt4Adapter.generateStreamableResponse({
message_content: messages,
instruction: assistant_instructions,
});

const responses = [];
for await (const response of result) {
responses.push(response);
}

// Assert the message content to contain the word akeru
const generatedResponse = responses.join("");
expect(generatedResponse.toLocaleLowerCase()).toContain("akeru");
});
});
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { Message } from "@/core/domain/messages";
import { BaseAdapter, StreamableAdapter } from "@/infrastructure/adapters/BaseAdapter";
import {
BaseAdapter,
StreamableAdapter,
} from "@/infrastructure/adapters/BaseAdapter";
import { AdapterRequest } from "../adapter";
import { GPTModels } from "./models";

Expand All @@ -17,35 +20,6 @@ interface ChatResponseChoice {
export interface OpenAIResponse {
choices: ChatResponseChoice[];
}

// export async function gpt4Adapter(
// messages: Pick<Message, "role" | "content">[],
// assistant_instructions: string
// ): Promise<unknown> {
// // System will always be the assistant_instruction that created the assistant
// const gpt_messages = [
// { role: "system", content: assistant_instructions },
// ].concat(messages);
// try {
// const res = await fetch("https://api.openai.com/v1/chat/completions", {
// method: "POST",
// headers: {
// "Content-Type": "application/json",
// Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
// },
// body: JSON.stringify({
// model: "gpt-4",
// messages: gpt_messages,
// }),
// });

// const data: OpenAIResponse = await res.json();
// return data;
// } catch (error) {
// return new Response("Error", { status: 500 });
// }
// }

export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
adapterName: string;
adapterDescription = "This adapter supports all adapter models from OpenAI";
Expand All @@ -57,7 +31,9 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
}

async generateSingleResponse(args: AdapterRequest): Promise<string> {
const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content);
const gpt_messages = [{ role: "system", content: args.instruction }].concat(
args.message_content
);
try {
const res = await fetch(this.OPENAI_ENDPOINT, {
method: "POST",
Expand All @@ -72,15 +48,43 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
});

const data: OpenAIResponse = await res.json();
const finished_inference = data.choices[0].message.content
const finished_inference = data.choices[0].message.content;
return Promise.resolve(finished_inference);
} catch (error) {
return Promise.reject(error);
}
}

async generateStreamableResponse(args: AdapterRequest): Promise<string> {
const gpt_messages = [{ role: "system", content: args.instruction }].concat(args.message_content);
async *chunksToLines(chunksAsync: any) {
let previous = "";
for await (const chunk of chunksAsync) {
const bufferChunk = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk);
previous += bufferChunk;
let eolIndex;
while ((eolIndex = previous.indexOf("\n")) >= 0) {
// line includes the EOL
const line = previous.slice(0, eolIndex + 1).trimEnd();
if (line === "data: [DONE]") break;
if (line.startsWith("data: ")) yield line;
previous = previous.slice(eolIndex + 1);
}
}
}

async *linesToMessages(linesAsync: any) {
for await (const line of linesAsync) {
const message = line.substring("data :".length);

yield message;
}
}

async *generateStreamableResponse(
args: AdapterRequest
): AsyncGenerator<string> {
const gpt_messages = [{ role: "system", content: args.instruction }].concat(
args.message_content
);
try {
const res = await fetch(this.OPENAI_ENDPOINT, {
method: "POST",
Expand All @@ -91,17 +95,32 @@ export class GPTAdapter extends BaseAdapter implements StreamableAdapter {
body: JSON.stringify({
model: this.adapterName,
messages: gpt_messages,
stream: true
stream: true,
}),
});

const data: OpenAIResponse = await res.json();
const finished_inference = data.choices[0].message.content
return Promise.resolve(finished_inference);
// This section of the code is taken from...
/// https://github.com/openai/openai-node/issues/18
const reader = res.body?.getReader();
while (true && reader) {
const { done, value } = await reader.read();
if (done) break;

const data = new TextDecoder().decode(value);
const chunkToLines = this.chunksToLines(data);
const linesToMessages = this.linesToMessages(chunkToLines);

for await (const message of linesToMessages) {
const messageObject: any = JSON.parse(message);

// MessageObject is the response from the OpenAI API streams
const messageToYield = messageObject.choices[0].delta.content;
if (messageToYield) yield messageToYield;
}
}
} catch (error) {
return Promise.reject(error);
}

}

getAdapterInformation(): Object {
Expand Down
Loading