diff --git a/.env.example b/.env.example
index 0b229f0..ba71d1a 100644
--- a/.env.example
+++ b/.env.example
@@ -46,11 +46,18 @@
 CLOUDFLARE_ACCOUNT_ID=abcdef1234567890abcdef1234567890
 CLOUDFLARE_ACCOUNT_TOKEN=v1.0-abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890
 
+# OpenAI API: Specify an OpenAI account token for use with the OpenAI API.
+OPENAI_ACCOUNT_TOKEN=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
 # The following variables specify instruction sets and configuration for the various models in SpongeChat.
 ## The value of the variable should correspond to the key of its' responding configuration in modelConstants.js.
 MODEL_LLM_PRESET=default
 
+## Callsystems are used to call functions during passive activation.
+## "integrations" is the newer, more flexible function system. It is easily extendable, but requires the OpenAI API to determine function calls.
+## "legacy" is the older function system and only supports image generation. Use it if you can't use the OpenAI API.
+MODEL_LLM_CALLSYSTEM=integrations
+
-# !! Wastebin
+# !! Wastebin
+# Used to display logs of memories for users
 ## In a docker-compose setup, you'll need to set up some sort of proxy (caddy, cloudflare tunnel) to make the "wastebin" container publicly accessible, and put the publicly accessible URL here.
diff --git a/bun.lockb b/bun.lockb
index 82d18f1..8da124c 100644
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/package.json b/package.json
index 8a4f398..6734e02 100644
--- a/package.json
+++ b/package.json
@@ -1,45 +1,48 @@
 {
-	"$schema": "https://json.schemastore.org/package.json",
-	"name": "spongechat",
-	"version": "2.0.1",
-	"private": true,
-	"type": "module",
-	"scripts": {
-		"lint": "prettier --check . && eslint --ext .js,.mjs,.cjs --format=pretty src",
-		"format": "prettier --write . && eslint --ext .js,.mjs,.cjs --fix --format=pretty src",
-		"start": "node --require dotenv/config src/index.js",
-		"cmd:undeploy": "node --require dotenv/config src/util/deploy/cli.js undeploy",
-		"cmd:deploy": "node --require dotenv/config src/util/deploy/cli.js deploy",
-		"ci:release": "dotenv -e .env -- release-it --config .release-it.cjs"
-	},
-	"_moduleAliases": {
-		"@": "./src",
-		"@util": "./src/util",
-		"@events": "./src/events",
-		"@commands": "./src/commands"
-	},
-	"dependencies": {
-		"@discordjs/core": "^1.1.0",
-		"@redis/json": "^1.0.6",
-		"chalk": "4",
-		"discord.js": "^14.15.2",
-		"dotenv": "^16.3.1",
-		"luxon": "^3.4.4",
-		"module-alias": "^2.2.3",
-		"redis": "^4.6.13",
-		"temporal-polyfill": "^0.2.4",
-		"undici": "^6.16.1",
-		"uuid": "^9.0.1"
-	},
-	"devDependencies": {
-		"@release-it/conventional-changelog": "^8.0.1",
-		"all-contributors-cli": "^6.26.1",
-		"dotenv-cli": "^7.4.2",
-		"eslint": "^8.53.0",
-		"eslint-config-neon": "^0.1.57",
-		"eslint-formatter-pretty": "^5.0.0",
-		"execa": "^9.1.0",
-		"prettier": "^3.0.3",
-		"release-it": "^17"
-	}
+	"$schema": "https://json.schemastore.org/package.json",
+	"name": "spongechat",
+	"version": "2.0.1",
+	"private": true,
+	"type": "module",
+	"scripts": {
+		"lint": "prettier --check . && eslint --ext .js,.mjs,.cjs --format=pretty src",
+		"format": "prettier --write . && eslint --ext .js,.mjs,.cjs --fix --format=pretty src",
+		"start": "node --require dotenv/config src/index.js",
+		"cmd:undeploy": "node --require dotenv/config src/util/deploy/cli.js undeploy",
+		"cmd:deploy": "node --require dotenv/config src/util/deploy/cli.js deploy",
+		"ci:release": "dotenv -e .env -- release-it --config .release-it.cjs"
+	},
+	"_moduleAliases": {
+		"@": "./src",
+		"@util": "./src/util",
+		"@events": "./src/events",
+		"@commands": "./src/commands"
+	},
+	"dependencies": {
+		"@ai-sdk/openai": "^0.0.13",
+		"@discordjs/core": "^1.1.0",
+		"@redis/json": "^1.0.6",
+		"ai": "^3.1.12",
+		"chalk": "4",
+		"discord.js": "^14.15.2",
+		"dotenv": "^16.3.1",
+		"luxon": "^3.4.4",
+		"module-alias": "^2.2.3",
+		"redis": "^4.6.13",
+		"temporal-polyfill": "^0.2.4",
+		"undici": "^6.16.1",
+		"uuid": "^9.0.1",
+		"zod": "^3.23.8"
+	},
+	"devDependencies": {
+		"@release-it/conventional-changelog": "^8.0.1",
+		"all-contributors-cli": "^6.26.1",
+		"dotenv-cli": "^7.4.2",
+		"eslint": "^8.53.0",
+		"eslint-config-neon": "^0.1.57",
+		"eslint-formatter-pretty": "^5.0.0",
+		"execa": "^9.1.0",
+		"prettier": "^3.0.3",
+		"release-it": "^17"
+	}
 }
diff --git a/src/commands/instructionSet.js b/src/commands/instructionSet.js
index 7a926a6..8a050d7 100644
--- a/src/commands/instructionSet.js
+++ b/src/commands/instructionSet.js
@@ -11,7 +11,7 @@ export default {
 			o
 				.setName("preset")
 				.setDescription("Preset; map to => client.tempStore#instructionSet")
-				.setChoices(Object.keys(instructionSets).map((s) => ({ name: s, value: s })))
+				.setChoices(Object.keys(instructionSets).map((s) => ({ name: instructionSets[s]?.name || s, value: s })))
 				.setRequired(true),
 		)
 		.toJSON(),
@@ -35,7 +35,7 @@ export default {
 		});
 
 		console.log(
-			`${chalk.bold.green("AI")} Instruction set preset changed to ${chalk.bold(toOption)} (${Temporal.Now.instant().toLocaleString("en-GB", { timeZone: "Etc/UTC", timeZoneName: "short" })})`,
+			`${chalk.bold.green("AI")} Instruction set preset changed to ${chalk.bold(instructionSets[toOption]?.name || toOption)} (${Temporal.Now.instant().toLocaleString("en-GB", { timeZone: "Etc/UTC", timeZoneName: "short" })})`,
 		);
 
 		if (sync) {
diff --git a/src/events/messageCreate.js b/src/events/messageCreate.js
index af098be..06355b6 100644
--- a/src/events/messageCreate.js
+++ b/src/events/messageCreate.js
@@ -9,7 +9,9 @@ const callTextChannel = async ({ client, message }) => {
 		baseHistory: [],
 		accountId: process.env.CLOUDFLARE_ACCOUNT_ID,
 		token: process.env.CLOUDFLARE_ACCOUNT_TOKEN,
+		openaiToken: process.env.OPENAI_ACCOUNT_TOKEN,
 		model: "@cf/meta/llama-3-8b-instruct",
+		callsystem: process.env.MODEL_LLM_CALLSYSTEM || "legacy",
 	});
 
 	const preliminaryConditions = modelInteractions.messageEvent.checkPreliminaryConditions();
@@ -40,20 +42,45 @@ const callTextChannel = async ({ client, message }) => {
 		})
 		.catch(console.error);
 
-	const { textResponse, genData, callResponse } = await modelInteractions.messageEvent.handleTextModelCall({ history });
+	const { legacy, runners, response } = await modelInteractions.messageEvent.preSend({ history });
 
-	if (callResponse.length === 0 || callResponse === "") return await message.react("⚠️").catch(() => false);
+	if (legacy?.active) {
+		const { textResponse, genData, callResponse } = legacy;
+		if (callResponse.length === 0 || callResponse === "") return await message.react("⚠️").catch(() => false);
 
-	const { responseMsg, events } = await modelInteractions.messageEvent.createResponse({
-		textResponse,
-		conditions: {
-			amnesia: !validityCheck?.valid && validityCheck?.handled?.isRequired && validityCheck?.handled?.executed,
-			imagine: callResponse.includes("!gen"),
-		},
-	});
+		const { responseMsg, events } = await modelInteractions.messageEvent.createLegacyResponse({
+			textResponse,
+			conditions: {
+				amnesia: !validityCheck?.valid && validityCheck?.handled?.isRequired && validityCheck?.handled?.executed,
+				imagine: callResponse.includes("!gen"),
+			},
+		});
+
+		if (responseMsg && callResponse.includes("!gen"))
+			return await modelInteractions.messageEvent.handleLegacyImageModelCall({
+				genData,
+				textResponse,
+				responseMsg,
+				events,
+			});
+
+		return;
+	}
+
+	if (response?.length === 0 || response === "") return await message.react("⚠️").catch(() => false);
 
-	if (responseMsg && callResponse.includes("!gen"))
-		return await modelInteractions.messageEvent.handleImageModelCall({ genData, textResponse, responseMsg, events });
+	const replyContent = modelInteractions.response.format(response);
+	const reply = await message
+		.reply({ content: replyContent.content, files: replyContent.files, failIfNotExists: true })
+		.catch(() => false);
+
+	if (runners.length > 0) {
+		const postRunners = await modelInteractions.messageEvent.postSend({ runners, message: reply });
+		const mergedFiles = [...replyContent.files, ...postRunners.results];
+		return await reply
+			.edit({ content: replyContent.content, files: mergedFiles, failIfNotExists: true })
+			.catch(() => false);
+	}
 };
 
 /** @type {import('./index.js').Event} */
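For reference, the contract the handler above now consumes, summarised as a comment. The shapes are taken from the preSend implementation in src/util/models/interactions.js later in this diff, not a separate API:

	// InteractionMessageEvent#preSend resolves to:
	// {
	//   legacy: { active, textResponse, genData, callResponse }, // legacy "!gen" protocol fields
	//   runners: ToolCall[], // post-stage integration calls; postSend runs them after the reply is sent
	//   response: string,    // raw assistant text, used by the integrations path
	// }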
diff --git a/src/events/ready.js b/src/events/ready.js
index 2ca292a..dd7227e 100644
--- a/src/events/ready.js
+++ b/src/events/ready.js
@@ -1,6 +1,7 @@
 import { Events } from "discord.js";
 import { Environment } from "../util/helpers.js";
 import { createClient } from "redis";
+import { instructionSets } from "../util/models/constants.js";
 import chalk from "chalk";
 
 const env = new Environment();
@@ -42,7 +43,9 @@ export default {
 		client.tempStore.set("instructionSet", instructionSet);
 
 		console.log(`${chalk.bold.green("AI")} Silent mode is ${chalk.bold(silentSaved ? "enabled" : "disabled")}`);
-		console.log(`${chalk.bold.green("AI")} Instruction set is ${chalk.bold(instructionSet)}`);
+		console.log(
+			`${chalk.bold.green("AI")} Instruction set is ${chalk.bold(instructionSets[instructionSet]?.name || instructionSet)}`,
+		);
 
 		console.log(
 			`${chalk.bold.green("Core")} acting as ${chalk.bold(client.user.tag)} (${Temporal.Now.instant().toLocaleString("en-GB", { timeZone: "Etc/UTC", timeZoneName: "short" })})`,
diff --git a/src/index.js b/src/index.js
index 869af27..5ad03cb 100644
--- a/src/index.js
+++ b/src/index.js
@@ -12,7 +12,10 @@ import "temporal-polyfill/global";
 
 (() => {
 	console.log(`${chalk.bold.green("Core")} running with environment context: ${chalk.bold(process.env.NODE_ENV)}`);
-	console.log(`${chalk.bold.magenta("AI")} running with LLM preset: ${chalk.bold(process.env.MODEL_LLM_PRESET)}`);
+	if (process.env.MODEL_LLM_CALLSYSTEM !== "integrations")
+		console.log(
+			`${chalk.bold.magenta("AI")} ${chalk.yellow("Warning")} The legacy call system is enabled. Integration calls are not available in this mode.`,
+		);
 })();
 
 // Initialize the client
diff --git a/src/util/integrations/index.js b/src/util/integrations/index.js
new file mode 100644
index 0000000..a9eefb6
--- /dev/null
+++ b/src/util/integrations/index.js
@@ -0,0 +1,82 @@
+import { tool } from "ai";
+import { z } from "zod";
+
+export class Integration {
+	// note: the tool's callable name comes from its key in InteractionIntegrations#integrations;
+	// the `name` option here is informational only
+	constructor({ name, description, parameters, stage }) {
+		this.tool = tool({
+			description,
+			parameters,
+		});
+
+		this.executionLevel = stage;
+	}
+
+	get stage() {
+		return this.executionLevel;
+	}
+
+	// pre-runner integrations run before the model call and can ONLY return a conversation-based output; () => Promise<{ role, content }>
+	// post-runner integrations run after the model call and can only return file-based outputs; () => Promise<Buffer>
+	async call() {
+		return {};
+	}
+}
+
+export class ImagineIntegration extends Integration {
+	constructor({ workersAI }) {
+		super({
+			name: "imagine",
+			description: "Generate an image with the given prompt",
+			parameters: z.object({
+				prompt: z.string().describe("The prompt to use for generating the image"),
+			}),
+			stage: "post",
+		});
+
+		this.workersAI = workersAI;
+	}
+
+	async call({ prompt }, ctx) {
+		const callToModel = await this.workersAI
+			.callModel(
+				{
+					model: "@cf/lykon/dreamshaper-8-lcm",
+					input: {
+						prompt,
+					},
+				},
+				true,
+			)
+			.then((r) => r.arrayBuffer())
+			.catch((e) => {
+				console.error(e);
+				return null;
+			});
+
+		if (callToModel === null) return null;
+
+		const buffer = Buffer.from(callToModel);
+
+		return buffer;
+	}
+}
+
+export class QoTDIntegration extends Integration {
+	constructor() {
+		super({
+			name: "qotd",
+			description: "Get the quote of the day",
+			parameters: z.object({
+				luckyWord: z.string().describe("The lucky word to randomise the quote with"),
+			}),
+			stage: "pre",
+		});
+	}
+
+	async call({ luckyWord }, ctx) {
+		return {
+			role: "system",
+			content: "[Function call to QOTD]: The quote of the day is skeebeedee guyatt toilet.",
+		};
+	}
+}
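Adding a new integration means subclassing Integration and registering an instance in the InteractionIntegrations constructor (see src/util/models/interactions.js below). A minimal sketch under those assumptions; the dice integration itself is illustrative and not part of this diff:

	import { z } from "zod";
	import { Integration } from "./index.js"; // i.e. src/util/integrations/index.js

	export class DiceIntegration extends Integration {
		constructor() {
			super({
				name: "dice",
				description: "Roll an n-sided die",
				parameters: z.object({
					sides: z.number().describe("The number of sides on the die"),
				}),
				stage: "pre", // "pre" output is spliced into the model context; "post" output becomes an attachment
			});
		}

		// pre-runners resolve to a conversation message ({ role, content })
		async call({ sides }) {
			return {
				role: "system",
				content: `[Function call to DICE]: The die landed on ${Math.floor(Math.random() * sides) + 1}.`,
			};
		}
	}

	// registered alongside the others in the InteractionIntegrations constructor:
	// this.integrations = { imagine: ..., quoteoftheday: ..., dice: new DiceIntegration() };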
diff --git a/src/util/models/constants.js b/src/util/models/constants.js
index 312d766..d0bc7ee 100644
--- a/src/util/models/constants.js
+++ b/src/util/models/constants.js
@@ -1,3 +1,6 @@
+import { tool } from "ai";
+import { z } from "zod";
+
 export const instructionSets = {
 	default: {
 		name: "Default",
diff --git a/src/util/models/index.js b/src/util/models/index.js
index f1dd258..b3242d4 100644
--- a/src/util/models/index.js
+++ b/src/util/models/index.js
@@ -1,5 +1,10 @@
 import { fetch } from "undici";
-import { InteractionResponse, InteractionHistory, InteractionMessageEvent } from "./interactions.js";
+import {
+	InteractionResponse,
+	InteractionHistory,
+	InteractionMessageEvent,
+	InteractionIntegrations,
+} from "./interactions.js";
 
 export class WorkersAI {
 	constructor(
@@ -55,14 +60,17 @@ export class ModelInteractions {
 		this.disabledModules = disabledModules;
 		this.history = disabledModules?.includes("history") ? null : new InteractionHistory(opts);
 		this.response = disabledModules?.includes("response") ? null : new InteractionResponse(opts);
+		this.integrations = disabledModules?.includes("integrations") ? null : new InteractionIntegrations(opts);
 		this.messageEvent = disabledModules?.includes("messageEvent")
 			? null
 			: new InteractionMessageEvent({
 					...opts,
 					interactionResponse: this.response,
 					interactionHistory: this.history,
+					interactionIntegrations: this.integrations,
 				});
 
 		this.model = opts?.model;
+		this.callsystem = opts?.callsystem;
 	}
 }
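The legacy fallback works because InteractionIntegrations short-circuits rather than being removed from the wiring: with callsystem "legacy", integrationCaller resolves to [] and execute returns [], so preSend still runs end to end with no OpenAI traffic. A small illustration of that guard, constructing the class directly with the option names from this diff (standalone use like this is a sketch, not how the bot wires it):

	const integrations = new InteractionIntegrations({
		message,
		accountId: process.env.CLOUDFLARE_ACCOUNT_ID,
		token: process.env.CLOUDFLARE_ACCOUNT_TOKEN,
		openaiToken: undefined, // not needed in legacy mode
		model: "@cf/meta/llama-3-8b-instruct",
		callsystem: "legacy",
	});

	await integrations.integrationCaller({ history: [] }); // => [] (no OpenAI call is made)
	await integrations.execute({ calls: [] });             // => []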
diff --git a/src/util/models/interactions.js b/src/util/models/interactions.js
index e736c08..8e4fb3e 100644
--- a/src/util/models/interactions.js
+++ b/src/util/models/interactions.js
@@ -1,6 +1,10 @@
 import { fetch } from "undici";
 import { events, instructionSets } from "./constants.js";
 import { WorkersAI } from "./index.js";
+import { createOpenAI } from "@ai-sdk/openai";
+import { generateText } from "ai";
+
+import { ImagineIntegration, QoTDIntegration } from "../integrations/index.js";
 
 export class InteractionHistory {
 	constructor(
@@ -14,31 +18,39 @@
 	) {
 		this.kv = kv;
 		this.contextWindow = contextWindow || 10;
-		this.instructionSet = instructionSets[instructionSet || "default"];
+		this.instructionSet = {
+			id: instructionSet,
+			...instructionSets[instructionSet || "default"],
+		};
 		this.baseHistory = [
+			...(this.instructionSet?.instructions || [
+				{
+					role: "system",
+					content: this?.instructionSet,
+				},
+			]),
 			...baseHistory,
-			{
-				role: "system",
-				content: this.instructionSet,
-			},
 		];
 		this.model = model;
 	}
 
-	async get({ key }, all = false) {
+	async get({ key, instructionSet = this.instructionSet?.id, window = this.contextWindow }, all = false) {
+		const baseHistory = instructionSets[instructionSet]?.instructions || this.baseHistory;
 		const fetchedMessages = (await this.kv.lRange(key, 0, -1))
-			.reverse()
 			.map((m) => JSON.parse(m))
+			// sort after parsing; the raw list entries are JSON strings and carry no .timestamp
+			.sort((a, b) => new Date(b.timestamp) - new Date(a.timestamp))
+			.reverse()
 			// only return the last [contextWindow] messages
 			// if all is true, return all messages
-			.slice(0, all ? -1 : this.contextWindow)
+			.slice(0, all ? -1 : window || this.contextWindow)
 			.reduce((acc, item, index) => {
 				// this reducer is very.. redundant, but i'm adding it for later
 				acc.push(item);
 				return acc;
 			}, []);
 
-		return [...this.baseHistory, ...fetchedMessages];
+		return [...baseHistory, ...fetchedMessages];
 	}
 
 	async add(
@@ -75,7 +87,9 @@
 			.lRange(key, 0, -1)
 			.then((r) => r.map((m) => JSON.parse(m)))
 			.catch(() => [])
-		).reverse();
+		)
+			.sort((a, b) => new Date(b.timestamp) - new Date(a.timestamp))
+			.reverse();
 
 		const interactions = current?.filter(typeof filter === "function" ? filter : (f) => f);
 		const formatted = interactions
@@ -188,7 +202,7 @@
 	}
 
 	formatAssistantMessage(content) {
-		return content.trim();
+		return content?.trim();
 	}
 
 	/**
@@ -207,12 +221,12 @@
 	 * @param {string} event.status The status of the event
 	 * @returns {string} The formatted message
 	 * @example
-	 * const message = await this.formatOutputMessage(content, event);
+	 * const message = await this.formatLegacyOutput(content, event);
 	 * console.log(message);
 	 * // Outputs the formatted message
 	 */
-	formatOutputMessage(content, allEvents = []) {
+	formatLegacyOutput(content, allEvents = []) {
 		const bannerArr = allEvents
 			.map((event) => {
 				const eventData = events[event?.type];
@@ -227,18 +241,103 @@
 		return banner + "\n" + content.trim();
 	}
 
+	format(input) {
+		if (!input)
+			return {
+				content: "",
+				files: [],
+			};
+
+		const content = input?.length >= 2000 ? "" : input;
+		const files = input?.length >= 2000 ? [{ attachment: Buffer.from(input, "utf-8"), name: "response.md" }] : [];
+
+		return {
+			content,
+			files,
+		};
+	}
+
 	currentTemporalISO() {
 		return Temporal.Now.plainDateTimeISO(this?.tz || "Etc/UTC").toString();
 	}
 }
 
+export class InteractionIntegrations {
+	constructor(
+		{ message, kv, model, accountId, token, openaiToken, callsystem } = {
+			kv: null,
+			instructionSet: process.env.MODEL_LLM_PRESET || "default",
+			baseHistory: [],
+			model: "@cf/meta/llama-3-8b-instruct",
+			contextWindow: 10,
+			callsystem: process.env.MODEL_LLM_CALLSYSTEM || "legacy",
+		},
+	) {
+		this.message = message;
+		this.kv = kv;
+		this.workersAI = new WorkersAI({ accountId, token, model });
+		this.openai = createOpenAI({
+			apiKey: openaiToken,
+		});
+		this.model = model;
+		this.callsystem = callsystem;
+
+		this.integrations = {
+			imagine: new ImagineIntegration({ workersAI: this.workersAI }),
+			quoteoftheday: new QoTDIntegration(),
+		};
+	}
+
+	get integrationSchemas() {
+		return Object.keys(this.integrations).reduce((acc, cv) => {
+			return {
+				...acc,
+				[cv]: this.integrations[cv].tool,
+			};
+		}, {});
+	}
+
+	async integrationCaller({ history }) {
+		if (this.callsystem === "legacy") return [];
+		const model = this.openai.chat("gpt-3.5-turbo", {
+			user: this.message?.author?.id,
+		});
+
+		const call = await generateText({
+			model,
+			system:
+				"You are a bot that can call functions. If no functions are required, respond with []. The previous user messages are only for context, you have already answered them.",
+			messages: history,
+			tools: this.integrationSchemas,
+		})
+			.then((r) => r.toolCalls)
+			.catch(() => []);
+
+		return call;
+	}
+
+	async execute({ calls, ctx }) {
+		if (calls.length === 0 || this.callsystem === "legacy") return [];
+		// for each integration, call the integration; return null (not undefined) for unknown tools
+		// so that downstream `r !== null` filters work
+		return Promise.all(
+			calls.map(async (call) => {
+				const integration = this.integrations[call.toolName];
+				if (typeof integration?.call !== "function") return null;
+				return await integration.call(call.args, ctx);
+			}),
+		);
+	}
+}
 
 export class InteractionMessageEvent {
-	constructor({ message, interactionResponse, interactionHistory, model }) {
+	constructor({ message, interactionResponse, interactionHistory, interactionIntegrations, callsystem, model }) {
 		this.message = message;
 		this.client = message?.client;
 		this.author = message?.author;
 		this.response = interactionResponse;
 		this.history = interactionHistory;
+		this.integrations = interactionIntegrations;
+		this.callsystem = callsystem;
 		this.model = model;
 	}
 
@@ -261,7 +360,7 @@
 	}
 
 	async validateHistory() {
-		const initialHistory = (await this.history.get({ key: this.message?.channel?.id })).filter(
+		const initialHistory = (await this.history.get({ key: this.message?.channel?.id }, true)).filter(
 			(e) => e.role === "assistant",
 		);
 
@@ -321,47 +420,8 @@
 		};
 	}
 
-	async handleTextModelCall({ history }) {
-		await this.message?.channel?.sendTyping();
-		const modelCall = await this.response.workersAI
-			.callModel({
-				input: {
-					messages: history.map((e) => ({
-						role: e.role,
-						content: e.content,
-					})),
-				},
-				maxTokens: 512,
-			})
-			.catch(() => ({
-				result: { response: "" },
-			}));
-
-		const callResponse = modelCall?.result?.response?.trim();
-		const textResponse = callResponse?.split("!gen")?.[0];
-		const genData = callResponse?.split("!gen")?.[1]?.replace("[", "").replace("]", "");
-
-		await this.history
-			.add(
-				{
-					key: this.message?.channel?.id,
-					role: "assistant",
-					content: this.response.formatAssistantMessage(textResponse?.length === 0 ? "[no response]" : textResponse),
-					respondingTo: this.message?.id,
-				},
-				true,
-			)
-			.catch(console.error);
-
-		return {
-			textResponse,
-			genData,
-			callResponse,
-		};
-	}
-
-	async handleImageModelCall({ genData, textResponse, responseMsg, events }) {
-		const final = this.response.formatOutputMessage(
+	async handleLegacyImageModelCall({ genData, textResponse, responseMsg, events }) {
+		const final = this.response.formatLegacyOutput(
 			textResponse,
 			events.filter((e) => e.type !== "imagine"),
 		);
@@ -408,7 +468,7 @@
 			.catch(() => null);
 	}
 
-	async createResponse(
+	async createLegacyResponse(
 		{ textResponse, conditions } = {
 			conditions: {
 				amnesia: false,
@@ -426,7 +486,7 @@
 			};
 		});
 
-		const text = this.response.formatOutputMessage(textResponse, events);
+		const text = this.response.formatLegacyOutput(textResponse, events);
 
 		const content = textResponse.length >= 2000 ? "" : text;
 		const files = textResponse.length >= 2000 ? [{ attachment: Buffer.from(text, "utf-8"), name: "response.md" }] : [];
@@ -443,4 +503,83 @@
 		};
 	}
+
+	async preSend({ history }) {
+		const callContext = await this.history.get({ key: this.message?.channel?.id }, true).catch(() => []);
+		const calls = await this.integrations
+			.integrationCaller({
+				history: callContext
+					.map((e) => ({
+						role: e.role,
+						content: e.content,
+					}))
+					.filter((e) => e.role === "user")
+					.slice(-2),
+			})
+			.then((r) =>
+				r.map((c) => ({
+					...c,
+					stage: this.integrations.integrations?.[c.toolName]?.stage,
+					execute: async () => {
+						// call the Integration instance, not the bare tool schema (schemas have no .call)
+						return await this.integrations.integrations?.[c.toolName]?.call(c.args);
+					},
+				})),
+			);
+		const preRunners = calls.filter((c) => c.stage === "pre");
+		const postRunners = calls.filter((c) => c.stage === "post");
+		const preRunnerResults = await this.integrations.execute({ calls: preRunners }).catch(() => []);
+		const allMessages = [...history.slice(0, -1), ...preRunnerResults, ...history.slice(-1)];
+
+		await this.message?.channel?.sendTyping();
+		const modelCall = await this.response.workersAI
+			.callModel({
+				input: {
+					messages: allMessages.map((e) => ({
+						role: e.role,
+						content: e.content,
+					})),
+				},
+				maxTokens: 512,
+			})
+			.catch(() => ({
+				result: { response: "" },
+			}));
+
+		const response = modelCall?.result?.response?.trim();
+
+		await this.history
+			.add(
+				{
+					key: this.message?.channel?.id,
+					role: "assistant",
+					content: this.response.formatAssistantMessage(response?.length === 0 ? "[no response]" : response),
+					respondingTo: this.message?.id,
+					context: {
+						integrations: calls.map((c) => ({ id: c.toolName, stage: c.stage, args: c.args })),
+					},
+				},
+				true,
+			)
+			.catch(console.error);
+
+		return {
+			legacy: {
+				active: this.callsystem === "legacy",
+				// keep the field semantics of the removed handleTextModelCall: callResponse is the
+				// full response, textResponse/genData are the parts before/after the "!gen" marker
+				textResponse: response?.split("!gen")?.[0],
+				genData: response?.split("!gen")?.[1]?.replace("[", "").replace("]", ""),
+				callResponse: response,
+			},
+			runners: postRunners,
+			response,
+		};
+	}
+
+	async postSend({ runners, message }) {
+		const runnerResults = await this.integrations.execute({ calls: runners, ctx: { message } }).catch(() => []);
+
+		return {
+			results: runnerResults.filter((r) => r !== null),
+		};
+	}
 }
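Taken together, a condensed sketch of a passive-activation round trip under the integrations callsystem, using an InteractionIntegrations instance like the one sketched earlier. The tool-call output shown is illustrative; the actual calls depend on the GPT-3.5 selection step:

	// 1. preSend asks GPT-3.5 which tools apply to the latest user messages
	const calls = await integrations.integrationCaller({
		history: [{ role: "user", content: "draw me a sunset over the sea" }],
	});
	// e.g. [{ toolName: "imagine", args: { prompt: "a sunset over the sea" } }]

	// 2. "pre" results are spliced into the Workers AI context before the reply is generated;
	//    "post" results (such as the imagine Buffer) are attached to the sent reply by postSend
	const results = await integrations.execute({ calls, ctx: {} });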