From 04ad081d7532200198d830be4756d08b4bf86558 Mon Sep 17 00:00:00 2001 From: xswl Date: Fri, 14 Jul 2023 18:50:11 +0800 Subject: [PATCH] add face swap --- .../{releases-4.2.yaml => v4.3.x.yaml} | 4 +- README.md | 103 +- example/faceswap.ts | 33 + package.json | 1 + src/discord.ws.ts | 16 +- src/face.swap.ts | 24 + src/gradio/client.ts | 1336 +++++++++++++++++ src/gradio/globals.d.ts | 31 + src/gradio/index.ts | 8 + src/gradio/types.ts | 109 ++ src/gradio/utils.ts | 211 +++ src/index.ts | 1 + src/interfaces/message.ts | 8 + src/midjourne.api.ts | 28 +- src/midjourney.ts | 28 +- src/utls/index.ts | 21 + test/face.ts | 25 + yarn.lock | 5 + 18 files changed, 1939 insertions(+), 53 deletions(-) rename .github/workflows/{releases-4.2.yaml => v4.3.x.yaml} (93%) create mode 100644 example/faceswap.ts create mode 100644 src/face.swap.ts create mode 100644 src/gradio/client.ts create mode 100644 src/gradio/globals.d.ts create mode 100644 src/gradio/index.ts create mode 100644 src/gradio/types.ts create mode 100644 src/gradio/utils.ts create mode 100644 test/face.ts diff --git a/.github/workflows/releases-4.2.yaml b/.github/workflows/v4.3.x.yaml similarity index 93% rename from .github/workflows/releases-4.2.yaml rename to .github/workflows/v4.3.x.yaml index 2431cb3..a658e07 100644 --- a/.github/workflows/releases-4.2.yaml +++ b/.github/workflows/v4.3.x.yaml @@ -1,6 +1,6 @@ -name: Package +name: verson 4.3.x env: - APPVERSION: 4.2.${{ github.run_number }} + APPVERSION: 4.3.${{ github.run_number }} on: workflow_dispatch: push: diff --git a/README.md b/README.md index 69df071..f525cb4 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # midjourney-api -Node.js client for the unofficial MidJourney api. -Major update New [niji bot](https://github.com/erictik/midjourney-api/blob/main/example/imagine-niji.ts) & [custom zoom](https://github.com/erictik/midjourney-api/blob/main/example/customzoom.ts) & [remix mode](https://github.com/erictik/midjourney-api/blob/main/example/variation-ws.ts) +Node.js client for the unofficial MidJourney api. +

Discord server @@ -9,6 +9,11 @@ Major update New [niji bot](https://github.com/erictik/midjourney-api/blob/main/

+## What's new + +- [face swap](https://github.com/erictik/midjourney-api/blob/main/example/faceswap.ts) +- Major update New [niji bot](https://github.com/erictik/midjourney-api/blob/main/example/imagine-niji.ts) & [custom zoom](https://github.com/erictik/midjourney-api/blob/main/example/customzoom.ts) & [remix mode](https://github.com/erictik/midjourney-api/blob/main/example/variation-ws.ts) + ## Install ```bash @@ -21,40 +26,42 @@ yarn add midjourney ```typescript import { Midjourney } from "midjourney"; - const client = new Midjourney({ - ServerId: process.env.SERVER_ID, - ChannelId: process.env.CHANNEL_ID, - SalaiToken: process.env.SALAI_TOKEN, - Debug: true, - Ws:true, - }); - await client.Connect(); - const Imagine = await client.Imagine("A little pink elephant", (uri: string, progress:string) => { - console.log("Imagine", uri, "progress", progress); - }); - console.log({ Imagine }); - - const Variation = await client.Variation({ - index: 2, - msgId: Imagine.id, - hash: Imagine.hash, - flags: Imagine.flags, - loading: (uri: string, progress: string) => { - console.log("Variation.loading", uri, "progress", progress); - }, - }); - console.log({ Variation }); - const Upscale = await client.Upscale({ - index: 2, - msgId: Variation.id, - hash: Variation.hash, - flags: Variation.flags, - loading: (uri: string, progress: string) => { - console.log("Upscale.loading", uri, "progress", progress); - }, - }); - console.log({ Upscale }); - +const client = new Midjourney({ + ServerId: process.env.SERVER_ID, + ChannelId: process.env.CHANNEL_ID, + SalaiToken: process.env.SALAI_TOKEN, + Debug: true, + Ws: true, +}); +await client.Connect(); +const Imagine = await client.Imagine( + "A little pink elephant", + (uri: string, progress: string) => { + console.log("Imagine", uri, "progress", progress); + } +); +console.log({ Imagine }); + +const Variation = await client.Variation({ + index: 2, + msgId: Imagine.id, + hash: Imagine.hash, + flags: Imagine.flags, + loading: (uri: string, progress: string) => { + console.log("Variation.loading", uri, "progress", progress); + }, +}); +console.log({ Variation }); +const Upscale = await client.Upscale({ + index: 2, + msgId: Variation.id, + hash: Variation.hash, + flags: Variation.flags, + loading: (uri: string, progress: string) => { + console.log("Upscale.loading", uri, "progress", progress); + }, +}); +console.log({ Upscale }); ``` ## Example @@ -77,12 +84,12 @@ npm install ``` 3. set the environment variables - - [How to get your Discord TOKEN:](https://www.androidauthority.com/get-discord-token-3149920/) - - [Create a server](https://discord.com/blog/starting-your-first-discord-server) and [Invite Midjourney Bot to Your Server](https://docs.midjourney.com/docs/invite-the-bot) - OR [Join my discord server](https://discord.com/invite/GavuGHQbV4) - - How to get server and channel ids: - when you click on a channel in your server in the browser expect to have the follow URL pattern `https://discord.com/channels/$SERVER_ID/$CHANNEL_ID` +- [How to get your Discord TOKEN:](https://www.androidauthority.com/get-discord-token-3149920/) +- [Create a server](https://discord.com/blog/starting-your-first-discord-server) and [Invite Midjourney Bot to Your Server](https://docs.midjourney.com/docs/invite-the-bot) + OR [Join my discord server](https://discord.com/invite/GavuGHQbV4) +- How to get server and channel ids: + when you click on a channel in your server in the browser expect to have the follow URL pattern `https://discord.com/channels/$SERVER_ID/$CHANNEL_ID` ```bash #example variables, please set up yours @@ -99,12 +106,13 @@ npx tsx example/imagine-ws.ts ``` ## route-map + - [x] `/imagine` `variation` `upscale` `reroll` `blend` `zoomout` `vary` - [x] `/info` - [x] `/fast ` and `/relax ` -- [x] [`/prefer remix`](https://github.com/erictik/midjourney-api/blob/main/example/prefer-remix.ts) +- [x] [`/prefer remix`](https://github.com/erictik/midjourney-api/blob/main/example/prefer-remix.ts) - [x] [`variation (remix mode)`](https://github.com/erictik/midjourney-api/blob/main/example/variation-ws.ts) -- [x] `/describe` +- [x] `/describe` - [x] [`/shorten`](https://github.com/erictik/midjourney-api/blob/main/example/shorten.ts) - [x] `/settings` `reset` - [x] verify human @@ -112,13 +120,22 @@ npx tsx example/imagine-ws.ts - [x] [niji bot](https://github.com/erictik/midjourney-api/blob/main/example/imagine-niji.ts) - [x] [custom zoom](https://github.com/erictik/midjourney-api/blob/main/example/customzoom.ts) - [x] autoload command payload + --- + ## Projects + - [midjourney-ui](https://github.com/erictik/midjourney-ui/) - [midjourney-discord](https://github.com/erictik/midjourney-discord)-bot - [phrame](https://github.com/jakowenko/phrame) - [guapitu](https://www.guapitu.com/zh/draw?code=RRXQNF) + --- + +## Support + Buy Me a Coffee + ## Star History + [![Star History Chart](https://api.star-history.com/svg?repos=erictik/midjourney-api&type=Date)](https://star-history.com/#erictik/midjourney-api&Date) diff --git a/example/faceswap.ts b/example/faceswap.ts new file mode 100644 index 0000000..309df18 --- /dev/null +++ b/example/faceswap.ts @@ -0,0 +1,33 @@ +import "dotenv/config"; +import { Midjourney, detectBannedWords } from "../src"; +/** + * + * a simple example of how to use faceSwap + * ``` + * npx tsx example/faceswap.ts + * ``` + */ +async function main() { + const source = `https://cdn.discordapp.com/attachments/1107965981839605792/1129362418775113789/3829c5d7-3e7e-473c-9c7b-b858e3ec97bc.jpeg`; + // const source = `https://cdn.discordapp.com/attachments/1108587422389899304/1129321826804306031/guapitu006_Cute_warrior_girl_in_the_style_of_Baten_Kaitos__111f39bc-329e-4fab-9af7-ee219fedf260.png`; + const target = `https://cdn.discordapp.com/attachments/1108587422389899304/1129321837042602016/guapitu006_a_girls_face_with_david_bowies_thunderbolt_71ee5899-bd45-4fc4-8c9d-92f19ddb0a03.png`; + const client = new Midjourney({ + ServerId: process.env.SERVER_ID, + ChannelId: process.env.CHANNEL_ID, + SalaiToken: process.env.SALAI_TOKEN, + Debug: true, + HuggingFaceToken: process.env.HUGGINGFACE_TOKEN, + }); + + const info = await client.FaceSwap(target, source); + console.log(info); +} +main() + .then(() => { + console.log("finished"); + process.exit(0); + }) + .catch((err) => { + console.error(err); + process.exit(1); + }); diff --git a/package.json b/package.json index 8997014..f915e85 100644 --- a/package.json +++ b/package.json @@ -47,6 +47,7 @@ "@huggingface/inference": "^2.5.0", "async": "^3.2.4", "isomorphic-ws": "^5.0.0", + "semiver": "^1.1.0", "snowyflake": "^2.0.0", "tslib": "^2.5.0", "ws": "^8.13.0" diff --git a/src/discord.ws.ts b/src/discord.ws.ts index 04be61a..fec1c16 100644 --- a/src/discord.ws.ts +++ b/src/discord.ws.ts @@ -9,6 +9,7 @@ import { MJOptions, OnModal, MJShorten, + MJDescribe, } from "./interfaces"; import { MidjourneyApi } from "./midjourne.api"; import { @@ -201,10 +202,16 @@ export class WsMessage { this.emit("settings", message); return; case "describe": - this.emitMJ(id, { + // console.log("describe", "meseesage", message); + const describe: MJDescribe = { + id: id, + flags: message.flags, descriptions: embeds?.[0]?.description.split("\n\n"), + uri: embeds?.[0]?.image?.url, + proxy_url: embeds?.[0]?.image?.proxy_url, options: formatOptions(components), - }); + }; + this.emitMJ(id, describe); break; case "prefer remix": if (content != "") { @@ -606,10 +613,7 @@ export class WsMessage { }); } async waitDescribe(nonce: string) { - return new Promise<{ - options: MJOptions[]; - descriptions: string[]; - } | null>((resolve) => { + return new Promise((resolve) => { this.onceMJ(nonce, (message) => { resolve(message); }); diff --git a/src/face.swap.ts b/src/face.swap.ts new file mode 100644 index 0000000..2b99709 --- /dev/null +++ b/src/face.swap.ts @@ -0,0 +1,24 @@ +import { client } from "./gradio/index"; + +export class faceSwap { + public hf_token?: string; + constructor(hf_token?: string) { + this.hf_token = hf_token; + } + async changeFace(Target: Blob, Source: Blob) { + const app = await client("https://felixrosberg-face-swap.hf.space/", { + hf_token: this.hf_token as any, + }); + // console.log("app", app); + const result: any = await app.predict(1, [ + Target, // blob in 'Target' Image component + Source, // blob in 'Source' Image component + 0, // number (numeric value between 0 and 100) in 'Anonymization ratio (%)' Slider component + 0, // number (numeric value between 0 and 100) in 'Adversarial defense ratio (%)' Slider component + "Compare", // string[] (array of strings) in 'Mode' Checkboxgroup component + ]); + // result.data; + return result.data; + // console.log(result.data[0]); + } +} diff --git a/src/gradio/client.ts b/src/gradio/client.ts new file mode 100644 index 0000000..03f4a57 --- /dev/null +++ b/src/gradio/client.ts @@ -0,0 +1,1336 @@ +import semiver from "semiver"; + +import { + process_endpoint, + RE_SPACE_NAME, + map_names_to_ids, + discussions_enabled, + get_space_hardware, + set_space_hardware, + set_space_timeout, + hardware_types, +} from "./utils"; + +import type { + EventType, + EventListener, + ListenerMap, + Event, + Payload, + PostResponse, + UploadResponse, + Status, + SpaceStatus, + SpaceStatusCallback, + FileData, +} from "./types"; + +import type { Config } from "./types"; +import WebSocket from "isomorphic-ws"; + +type event = ( + eventType: K, + listener: EventListener +) => SubmitReturn; +type predict = ( + endpoint: string | number, + data?: unknown[], + event_data?: unknown +) => Promise; + +type client_return = { + predict: predict; + config: Config; + submit: ( + endpoint: string | number, + data?: unknown[], + event_data?: unknown + ) => SubmitReturn; + view_api: (c?: Config) => Promise>; +}; + +type SubmitReturn = { + on: event; + off: event; + cancel: () => Promise; + destroy: () => void; +}; + +const QUEUE_FULL_MSG = "This application is too busy. Keep trying!"; +const BROKEN_CONNECTION_MSG = "Connection errored out."; + +export let NodeBlob: Blob; + +export async function duplicate( + app_reference: string, + options: { + hf_token: `hf_${string}`; + private?: boolean; + status_callback: SpaceStatusCallback; + hardware?: (typeof hardware_types)[number]; + timeout?: number; + } +) { + const { hf_token, private: _private, hardware, timeout } = options; + + if (hardware && !hardware_types.includes(hardware)) { + throw new Error( + `Invalid hardware type provided. Valid types are: ${hardware_types + .map((v) => `"${v}"`) + .join(",")}.` + ); + } + const headers = { + Authorization: `Bearer ${hf_token}`, + }; + + const user = ( + await ( + await fetch(`https://huggingface.co/api/whoami-v2`, { + headers, + }) + ).json() + ).name; + + const space_name = app_reference.split("/")[1]; + const body: { + repository: string; + private?: boolean; + } = { + repository: `${user}/${space_name}`, + }; + + if (_private) { + body.private = true; + } + + try { + const response = await fetch( + `https://huggingface.co/api/spaces/${app_reference}/duplicate`, + { + method: "POST", + headers: { "Content-Type": "application/json", ...headers }, + body: JSON.stringify(body), + } + ); + + if (response.status === 409) { + return client(`${user}/${space_name}`, options); + } else { + const duplicated_space = await response.json(); + + let original_hardware; + + if (!hardware) { + original_hardware = await get_space_hardware(app_reference, hf_token); + } + + const requested_hardware = hardware || original_hardware || "cpu-basic"; + await set_space_hardware( + `${user}/${space_name}`, + requested_hardware, + hf_token + ); + + await set_space_timeout( + `${user}/${space_name}`, + timeout || 300, + hf_token + ); + return client(duplicated_space.url, options); + } + } catch (e: any) { + throw new Error(e); + } +} + +/** + * We need to inject a customized fetch implementation for the Wasm version. + */ +export function api_factory(fetch_implementation: typeof fetch) { + return { post_data, upload_files, client, handle_blob }; + + async function post_data( + url: string, + body: unknown, + token?: `hf_${string}` + ): Promise<[PostResponse, number]> { + const headers: { + Authorization?: string; + "Content-Type": "application/json"; + } = { "Content-Type": "application/json" }; + if (token) { + headers.Authorization = `Bearer ${token}`; + } + try { + var response = await fetch_implementation(url, { + method: "POST", + body: JSON.stringify(body), + headers, + }); + } catch (e) { + return [{ error: BROKEN_CONNECTION_MSG }, 500]; + } + const output: PostResponse = await response.json(); + return [output, response.status]; + } + + async function upload_files( + root: string, + files: Array, + token?: `hf_${string}` + ): Promise { + const headers: { + Authorization?: string; + } = {}; + if (token) { + headers.Authorization = `Bearer ${token}`; + } + + const formData = new FormData(); + files.forEach((file) => { + formData.append("files", file); + }); + try { + var response = await fetch_implementation(`${root}/upload`, { + method: "POST", + body: formData, + headers, + }); + } catch (e) { + return { error: BROKEN_CONNECTION_MSG }; + } + const output: UploadResponse["files"] = await response.json(); + return { files: output }; + } + + async function client( + app_reference: string, + options: { + hf_token?: `hf_${string}`; + status_callback?: SpaceStatusCallback; + normalise_files?: boolean; + } = { normalise_files: true } + ): Promise { + return new Promise(async (res) => { + const { status_callback, hf_token, normalise_files } = options; + const return_obj = { + predict, + submit, + view_api, + // duplicate + }; + + const transform_files = normalise_files ?? true; + // if (typeof window === "undefined" || !("WebSocket" in window)) { + // const ws = await import("ws"); + // NodeBlob = (await import("node:buffer")).Blob; + // //@ts-ignore + // global.WebSocket = ws.WebSocket; + // } + + const { ws_protocol, http_protocol, host, space_id } = + await process_endpoint(app_reference, hf_token); + + const session_hash = Math.random().toString(36).substring(2); + const last_status: Record = {}; + let config: Config; + let api_map: Record = {}; + + let jwt: false | string = false; + + if (hf_token && space_id) { + jwt = await get_jwt(space_id, hf_token); + } + + async function config_success(_config: Config) { + config = _config; + api_map = map_names_to_ids(_config?.dependencies || []); + try { + api = await view_api(config); + } catch (e: any) { + console.error(`Could not get api details: ${e.message}`); + } + + return { + config, + ...return_obj, + }; + } + let api: ApiInfo; + async function handle_space_sucess(status: SpaceStatus) { + if (status_callback) status_callback(status); + if (status.status === "running") + try { + config = await resolve_config( + fetch_implementation, + `${http_protocol}//${host}`, + hf_token + ); + + const _config: any = await config_success(config); + res(_config); + } catch (e) { + console.error(e); + if (status_callback) { + status_callback({ + status: "error", + message: "Could not load this space.", + load_status: "error", + detail: "NOT_FOUND", + }); + } + } + } + + try { + config = await resolve_config( + fetch_implementation, + `${http_protocol}//${host}`, + hf_token + ); + + const _config: any = await config_success(config); + res(_config); + } catch (e) { + console.error(e); + if (space_id) { + check_space_status( + space_id, + RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain", + handle_space_sucess + ); + } else { + if (status_callback) + status_callback({ + status: "error", + message: "Could not load this space.", + load_status: "error", + detail: "NOT_FOUND", + }); + } + } + + /** + * Run a prediction. + * @param endpoint - The prediction endpoint to use. + * @param status_callback - A function that is called with the current status of the prediction immediately and every time it updates. + * @return Returns the data for the prediction or an error message. + */ + function predict( + endpoint: string, + data: unknown[], + event_data?: unknown + ) { + let data_returned = false; + let status_complete = false; + return new Promise((res, rej) => { + const app = submit(endpoint, data, event_data); + + app + .on("data", (d) => { + data_returned = true; + if (status_complete) { + app.destroy(); + } + res(d); + }) + .on("status", (status) => { + if (status.stage === "error") rej(status); + if (status.stage === "complete" && data_returned) { + app.destroy(); + } + if (status.stage === "complete") { + status_complete = true; + } + }); + }); + } + + function submit( + endpoint: string | number, + data: unknown[], + event_data?: unknown + ): SubmitReturn { + let fn_index: number; + let api_info: EndpointInfo; + + if (typeof endpoint === "number") { + fn_index = endpoint; + api_info = api.unnamed_endpoints[fn_index]; + } else { + const trimmed_endpoint = endpoint.replace(/^\//, ""); + + fn_index = api_map[trimmed_endpoint]; + api_info = api.named_endpoints[endpoint.trim()]; + } + + if (typeof fn_index !== "number") { + throw new Error( + "There is no endpoint matching that name of fn_index matching that number." + ); + } + + let websocket: WebSocket; + + const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint; + let payload: Payload; + let complete: false | Record = false; + const listener_map: ListenerMap = {}; + + handle_blob( + `${http_protocol}//${host + config.path}`, + data, + api_info, + hf_token + ).then((_payload) => { + payload = { data: _payload || [], event_data, fn_index }; + if (skip_queue(fn_index, config)) { + fire_event({ + type: "status", + endpoint: _endpoint, + stage: "pending", + queue: false, + fn_index, + time: new Date(), + }); + + post_data( + `${http_protocol}//${host + config.path}/run${ + _endpoint.startsWith("/") ? _endpoint : `/${_endpoint}` + }`, + { + ...payload, + session_hash, + }, + hf_token + ) + .then(([output, status_code]) => { + const data = transform_files + ? transform_output( + output.data, + api_info, + config.root, + config.root_url + ) + : output.data; + if (status_code == 200) { + fire_event({ + type: "data", + endpoint: _endpoint, + fn_index, + data: data, + time: new Date(), + }); + + fire_event({ + type: "status", + endpoint: _endpoint, + fn_index, + stage: "complete", + eta: output.average_duration, + queue: false, + time: new Date(), + }); + } else { + fire_event({ + type: "status", + stage: "error", + endpoint: _endpoint, + fn_index, + message: output.error, + queue: false, + time: new Date(), + }); + } + }) + .catch((e) => { + fire_event({ + type: "status", + stage: "error", + message: e.message, + endpoint: _endpoint, + fn_index, + queue: false, + time: new Date(), + }); + }); + } else { + fire_event({ + type: "status", + stage: "pending", + queue: true, + endpoint: _endpoint, + fn_index, + time: new Date(), + }); + + let url = new URL(`${ws_protocol}://${host}${config.path} + /queue/join`); + + if (jwt) { + url.searchParams.set("__sign", jwt); + } + + websocket = new WebSocket(url); + + websocket.onclose = (evt) => { + if (!evt.wasClean) { + fire_event({ + type: "status", + stage: "error", + message: BROKEN_CONNECTION_MSG, + queue: true, + endpoint: _endpoint, + fn_index, + time: new Date(), + }); + } + }; + + websocket.onmessage = function (event) { + const _data = JSON.parse(event.data.toString()); + const { type, status, data } = handle_message( + _data, + last_status[fn_index] + ); + + if (type === "update" && status && !complete) { + // call 'status' listeners + fire_event({ + type: "status", + endpoint: _endpoint, + fn_index, + time: new Date(), + ...status, + }); + if (status.stage === "error") { + websocket.close(); + } + } else if (type === "hash") { + websocket.send(JSON.stringify({ fn_index, session_hash })); + return; + } else if (type === "data") { + websocket.send(JSON.stringify({ ...payload, session_hash })); + } else if (type === "complete") { + complete = status || false; + } else if (type === "generating") { + fire_event({ + type: "status", + time: new Date(), + ...status, + stage: status?.stage!, + queue: true, + endpoint: _endpoint, + fn_index, + }); + } + if (data) { + fire_event({ + type: "data", + time: new Date(), + data: transform_files + ? transform_output( + data.data, + api_info, + config.root, + config.root_url + ) + : data.data, + endpoint: _endpoint, + fn_index, + }); + + if (complete) { + fire_event({ + type: "status", + time: new Date(), + ...complete, + stage: status?.stage!, + queue: true, + endpoint: _endpoint, + fn_index, + }); + websocket.close(); + } + } + }; + + // different ws contract for gradio versions older than 3.6.0 + //@ts-ignore + if (semiver(config.version || "2.0.0", "3.6") < 0) { + addEventListener("open", () => + websocket.send(JSON.stringify({ hash: session_hash })) + ); + } + } + }); + + function fire_event(event: Event) { + const narrowed_listener_map: ListenerMap = listener_map; + const listeners = narrowed_listener_map[event.type] || []; + listeners?.forEach((l) => l(event)); + } + + function on( + eventType: K, + listener: EventListener + ) { + const narrowed_listener_map: ListenerMap = listener_map; + const listeners = narrowed_listener_map[eventType] || []; + narrowed_listener_map[eventType] = listeners; + listeners?.push(listener); + + return { on, off, cancel, destroy }; + } + + function off( + eventType: K, + listener: EventListener + ) { + const narrowed_listener_map: ListenerMap = listener_map; + let listeners = narrowed_listener_map[eventType] || []; + listeners = listeners?.filter((l) => l !== listener); + narrowed_listener_map[eventType] = listeners; + + return { on, off, cancel, destroy }; + } + + async function cancel() { + const _status: Status = { + stage: "complete", + queue: false, + time: new Date(), + }; + complete = _status; + fire_event({ + ..._status, + type: "status", + endpoint: _endpoint, + fn_index: fn_index, + }); + + if (websocket && websocket.readyState === 0) { + websocket.addEventListener("open", () => { + websocket.close(); + }); + } else { + websocket.close(); + } + + try { + await fetch_implementation( + `${http_protocol}//${host + config.path}/reset`, + { + headers: { "Content-Type": "application/json" }, + method: "POST", + body: JSON.stringify({ fn_index, session_hash }), + } + ); + } catch (e) { + console.warn( + "The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable." + ); + } + } + + function destroy() { + for (const event_type in listener_map) { + listener_map[event_type as "data" | "status"]?.forEach((fn) => { + off(event_type as "data" | "status", fn); + }); + } + } + + return { + on, + off, + cancel, + destroy, + }; + } + + async function view_api(config?: Config): Promise> { + if (api) return api; + + const headers: { + Authorization?: string; + "Content-Type": "application/json"; + } = { "Content-Type": "application/json" }; + if (hf_token) { + headers.Authorization = `Bearer ${hf_token}`; + } + let response: Response; + // @ts-ignore + if (semiver(config.version || "2.0.0", "3.30") < 0) { + response = await fetch_implementation( + "https://gradio-space-api-fetcher-v2.hf.space/api", + { + method: "POST", + body: JSON.stringify({ + serialize: false, + config: JSON.stringify(config), + }), + headers, + } + ); + } else { + response = await fetch_implementation(`${config?.root}/info`, { + headers, + }); + } + + if (!response.ok) { + throw new Error(BROKEN_CONNECTION_MSG); + } + + let api_info = (await response.json()) as + | ApiInfo + | { api: ApiInfo }; + if ("api" in api_info) { + api_info = api_info.api; + } + + if ( + api_info.named_endpoints["/predict"] && + !api_info.unnamed_endpoints["0"] + ) { + api_info.unnamed_endpoints[0] = api_info.named_endpoints["/predict"]; + } + + const x = transform_api_info(api_info, config as Config, api_map); + return x; + } + }); + } + + async function handle_blob( + endpoint: string, + data: unknown[], + api_info: any, + token?: `hf_${string}` + ): Promise { + const blob_refs = await walk_and_store_blobs( + data, + undefined, + [], + true, + api_info + ); + + return Promise.all( + blob_refs.map( + async ({ + path, + blob, + data, + type, + }: { + path: string; + blob: Blob; + data: string; + type: string; + }) => { + if (blob) { + const file_url = (await upload_files(endpoint, [blob], token)) + ?.files?.[0]; + return { path, file_url, type }; + } else { + return { path, base64: data, type }; + } + } + ) + ).then((r) => { + r.forEach(({ path, file_url, base64, type }) => { + if (base64) { + update_object(data, base64, path); + } else if (type === "Gallery") { + update_object(data, file_url, path); + } else if (file_url) { + const o = { + is_file: true, + name: `${file_url}`, + data: null, + // orig_name: "file.csv" + }; + update_object(data, o, path); + } + }); + + return data; + }); + } +} + +export const { post_data, upload_files, client, handle_blob } = + api_factory(fetch); + +function transform_output( + data: any[], + api_info: any, + root_url: string, + remote_url: string | null = null +): unknown[] { + return data.map((d, i) => { + if (api_info.returns?.[i]?.component === "File") { + return normalise_file(d, root_url, remote_url); + } else if (api_info.returns?.[i]?.component === "Gallery") { + return d.map((img: any) => { + return Array.isArray(img) + ? [normalise_file(img[0], root_url, remote_url), img[1]] + : [normalise_file(img, root_url, remote_url), null]; + }); + } else if (typeof d === "object" && d.is_file) { + return normalise_file(d, root_url, remote_url); + } else { + return d; + } + }); +} + +function normalise_file( + file: Array, + root: string, + root_url: string | null +): Array; +function normalise_file( + file: FileData | string, + root: string, + root_url: string | null +): FileData; +function normalise_file( + file: null, + root: string, + root_url: string | null +): null; +function normalise_file( + file: any, + root: any, + root_url: any +): Array | FileData | null { + if (file == null) return null; + if (typeof file === "string") { + return { + name: "file_data", + data: file, + }; + } else if (Array.isArray(file)) { + const normalized_file: Array = []; + + for (const x of file) { + if (x === null) { + normalized_file.push(null); + } else { + normalized_file.push(normalise_file(x, root, root_url)); + } + } + + return normalized_file as Array; + } else if (file.is_file) { + if (!root_url) { + file.data = root + "/file=" + file.name; + } else { + file.data = "/proxy=" + root_url + "file=" + file.name; + } + } + return file; +} + +interface ApiData { + label: string; + type: { + type: any; + description: string; + }; + component: string; + example_input?: any; +} + +interface JsApiData { + label: string; + type: string; + component: string; + example_input: any; +} + +interface EndpointInfo { + parameters: T[]; + returns: T[]; +} +interface ApiInfo { + named_endpoints: { + [key: string]: EndpointInfo; + }; + unnamed_endpoints: { + [key: string]: EndpointInfo; + }; +} + +function get_type( + type: { [key: string]: any }, + component: string, + serializer: string, + signature_type: "return" | "parameter" +) { + switch (type.type) { + case "string": + return "string"; + case "boolean": + return "boolean"; + case "number": + return "number"; + } + + if ( + serializer === "JSONSerializable" || + serializer === "StringSerializable" + ) { + return "any"; + } else if (serializer === "ListStringSerializable") { + return "string[]"; + } else if (component === "Image") { + return signature_type === "parameter" ? "Blob | File | Buffer" : "string"; + } else if (serializer === "FileSerializable") { + if (type?.type === "array") { + return signature_type === "parameter" + ? "(Blob | File | Buffer)[]" + : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}[]`; + } else { + return signature_type === "parameter" + ? "Blob | File | Buffer" + : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}`; + } + } else if (serializer === "GallerySerializable") { + return signature_type === "parameter" + ? "[(Blob | File | Buffer), (string | null)][]" + : `[{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}, (string | null))][]`; + } +} + +function get_description( + type: { type: any; description: string }, + serializer: string +) { + if (serializer === "GallerySerializable") { + return "array of [file, label] tuples"; + } else if (serializer === "ListStringSerializable") { + return "array of strings"; + } else if (serializer === "FileSerializable") { + return "array of files or single file"; + } else { + return type.description; + } +} + +function transform_api_info( + api_info: ApiInfo, + config: Config, + api_map: Record +): ApiInfo { + let new_data = { + named_endpoints: {}, + unnamed_endpoints: {}, + }; + let key: keyof ApiInfo; + for (key in api_info) { + const cat = api_info[key]; + + for (const endpoint in cat) { + const dep_index = config.dependencies[endpoint as any] + ? endpoint + : api_map[endpoint.replace("/", "")]; + + const info = cat[endpoint]; + // @ts-ignore + new_data[key][endpoint] = {}; + // @ts-ignore + new_data[key][endpoint].parameters = {}; + // @ts-ignore + new_data[key][endpoint].returns = {}; + // @ts-ignore + new_data[key][endpoint].type = + config.dependencies[dep_index as any].types; + // @ts-ignore + new_data[key][endpoint].parameters = info.parameters.map( + // @ts-ignore + ({ label, component, type, serializer }) => ({ + label, + component, + type: get_type(type, component, serializer, "parameter"), + description: get_description(type, serializer), + }) + ); + // @ts-ignore + new_data[key][endpoint].returns = info.returns.map( + // @ts-ignore + ({ label, component, type, serializer }) => ({ + label, + component, + type: get_type(type, component, serializer, "return"), + description: get_description(type, serializer), + }) + ); + } + } + + return new_data; +} + +async function get_jwt( + space: string, + token: `hf_${string}` +): Promise { + try { + const r = await fetch(`https://huggingface.co/api/spaces/${space}/jwt`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); + + const jwt = (await r.json()).token; + + return jwt || false; + } catch (e) { + console.error(e); + return false; + } +} + +function update_object(object: any, newValue: any, stack: any) { + while (stack.length > 1) { + object = object[stack.shift()]; + } + + object[stack.shift()] = newValue; +} + +export async function walk_and_store_blobs( + param: any, + type = undefined, + path: any[] = [], + root = false, + api_info: any = undefined +) { + if (Array.isArray(param)) { + let blob_refs: any[] = []; + + await Promise.all( + param.map(async (v, i) => { + let new_path = path.slice(); + new_path.push(i); + + const array_refs = await walk_and_store_blobs( + param[i], + root ? api_info?.parameters[i]?.component || undefined : type, + new_path, + false, + api_info + ); + + blob_refs = blob_refs.concat(array_refs); + }) + ); + + return blob_refs; + } else if (globalThis.Buffer && param instanceof globalThis.Buffer) { + const is_image = type === "Image"; + return [ + { + path: path, + blob: is_image ? false : new NodeBlob([param]), + data: is_image ? `${param.toString("base64")}` : false, + type, + }, + ]; + } else if ( + param instanceof Blob || + (typeof window !== "undefined" && param instanceof File) + ) { + if (type === "Image") { + let data; + + if (typeof window !== "undefined") { + // browser + data = await image_to_data_uri(param); + } else { + const buffer = await param.arrayBuffer(); + data = Buffer.from(buffer).toString("base64"); + } + + return [{ path, data, type }]; + } else { + return [{ path: path, blob: param, type }]; + } + } else if (typeof param === "object") { + let blob_refs: any[] = []; + for (let key in param) { + if (param.hasOwnProperty(key)) { + let new_path: any[] = path.slice(); + new_path.push(key); + blob_refs = blob_refs.concat( + await walk_and_store_blobs( + param[key], + undefined, + new_path, + false, + api_info + ) + ); + } + } + return blob_refs; + } else { + return []; + } +} + +function image_to_data_uri(blob: Blob) { + return new Promise((resolve, _) => { + const reader = new FileReader(); + reader.onloadend = () => resolve(reader.result); + reader.readAsDataURL(blob); + }); +} + +function skip_queue(id: number, config: Config) { + return ( + !(config?.dependencies?.[id]?.queue === null + ? config.enable_queue + : config?.dependencies?.[id]?.queue) || false + ); +} + +async function resolve_config( + fetch_implementation: typeof fetch, + endpoint?: string, + token?: `hf_${string}` +): Promise { + const headers: { Authorization?: string } = {}; + if (token) { + headers.Authorization = `Bearer ${token}`; + } + if ( + typeof window !== "undefined" && + window.gradio_config && + location.origin !== "http://localhost:9876" + ) { + const path = window.gradio_config.root; + const config = window.gradio_config; + config.root = endpoint + config.root; + return { ...config, path: path }; + } else if (endpoint) { + let response = await fetch_implementation(`${endpoint}/config`, { + headers, + }); + + if (response.status === 200) { + const config = await response.json(); + config.path = config.path ?? ""; + config.root = endpoint; + return config; + } else { + throw new Error("Could not get config."); + } + } + + throw new Error("No config or app endpoint found"); +} + +async function check_space_status( + id: string, + type: "subdomain" | "space_name", + status_callback: SpaceStatusCallback +) { + let endpoint = + type === "subdomain" + ? `https://huggingface.co/api/spaces/by-subdomain/${id}` + : `https://huggingface.co/api/spaces/${id}`; + let response; + let _status; + try { + response = await fetch(endpoint); + _status = response.status; + if (_status !== 200) { + throw new Error(); + } + response = await response.json(); + } catch (e) { + status_callback({ + status: "error", + load_status: "error", + message: "Could not get space status", + detail: "NOT_FOUND", + }); + return; + } + + if (!response || _status !== 200) return; + const { + runtime: { stage }, + id: space_name, + } = response; + + switch (stage) { + case "STOPPED": + case "SLEEPING": + status_callback({ + status: "sleeping", + load_status: "pending", + message: "Space is asleep. Waking it up...", + detail: stage, + }); + + setTimeout(() => { + check_space_status(id, type, status_callback); + }, 1000); // poll for status + break; + case "PAUSED": + status_callback({ + status: "paused", + load_status: "error", + message: + "This space has been paused by the author. If you would like to try this demo, consider duplicating the space.", + detail: stage, + discussions_enabled: await discussions_enabled(space_name), + }); + break; + case "RUNNING": + case "RUNNING_BUILDING": + status_callback({ + status: "running", + load_status: "complete", + message: "", + detail: stage, + }); + // load_config(source); + // launch + break; + case "BUILDING": + status_callback({ + status: "building", + load_status: "pending", + message: "Space is building...", + detail: stage, + }); + + setTimeout(() => { + check_space_status(id, type, status_callback); + }, 1000); + break; + default: + status_callback({ + status: "space_error", + load_status: "error", + message: "This space is experiencing an issue.", + detail: stage, + discussions_enabled: await discussions_enabled(space_name), + }); + break; + } +} + +function handle_message( + data: any, + last_status: Status["stage"] +): { + type: "hash" | "data" | "update" | "complete" | "generating" | "none"; + data?: any; + status?: Status; +} { + const queue = true; + switch (data.msg) { + case "send_data": + return { type: "data" }; + case "send_hash": + return { type: "hash" }; + case "queue_full": + return { + type: "update", + status: { + queue, + message: QUEUE_FULL_MSG, + stage: "error", + code: data.code, + success: data.success, + }, + }; + case "estimation": + return { + type: "update", + status: { + queue, + stage: last_status || "pending", + code: data.code, + size: data.queue_size, + position: data.rank, + eta: data.rank_eta, + success: data.success, + }, + }; + case "progress": + return { + type: "update", + status: { + queue, + stage: "pending", + code: data.code, + progress_data: data.progress_data, + success: data.success, + }, + }; + case "process_generating": + return { + type: "generating", + status: { + queue, + message: !data.success ? data.output.error : null, + stage: data.success ? "generating" : "error", + code: data.code, + progress_data: data.progress_data, + eta: data.average_duration, + }, + data: data.success ? data.output : null, + }; + case "process_completed": + if ("error" in data.output) { + return { + type: "update", + status: { + queue, + message: data.output.error as string, + stage: "error", + code: data.code, + success: data.success, + }, + }; + } else { + return { + type: "complete", + status: { + queue, + message: !data.success ? data.output.error : undefined, + stage: data.success ? "complete" : "error", + code: data.code, + progress_data: data.progress_data, + eta: data.output.average_duration, + }, + data: data.success ? data.output : null, + }; + } + + case "process_starts": + return { + type: "update", + status: { + queue, + stage: "pending", + code: data.code, + size: data.rank, + position: 0, + success: data.success, + }, + }; + } + + return { type: "none", status: { stage: "error", queue } }; +} diff --git a/src/gradio/globals.d.ts b/src/gradio/globals.d.ts new file mode 100644 index 0000000..9abbf76 --- /dev/null +++ b/src/gradio/globals.d.ts @@ -0,0 +1,31 @@ +declare global { + interface Window { + __gradio_mode__: "app" | "website"; + launchGradio: Function; + launchGradioFromSpaces: Function; + gradio_config: Config; + scoped_css_attach: (link: HTMLLinkElement) => void; + __is_colab__: boolean; + } +} + +export interface Config { + auth_required: boolean | undefined; + auth_message: string; + components: any[]; + css: string | null; + dependencies: any[]; + dev_mode: boolean; + enable_queue: boolean; + layout: any; + mode: "blocks" | "interface"; + root: string; + theme: string; + title: string; + version: string; + space_id: string | null; + is_colab: boolean; + show_api: boolean; + stylesheets: string[]; + path: string; +} diff --git a/src/gradio/index.ts b/src/gradio/index.ts new file mode 100644 index 0000000..6061207 --- /dev/null +++ b/src/gradio/index.ts @@ -0,0 +1,8 @@ +export { + client, + post_data, + upload_files, + duplicate, + api_factory, +} from "./client"; +export type { SpaceStatus } from "./types"; diff --git a/src/gradio/types.ts b/src/gradio/types.ts new file mode 100644 index 0000000..9e283cf --- /dev/null +++ b/src/gradio/types.ts @@ -0,0 +1,109 @@ +export interface Config { + auth_required: boolean | undefined; + auth_message: string; + components: any[]; + css: string | null; + dependencies: any[]; + dev_mode: boolean; + enable_queue: boolean; + layout: any; + mode: "blocks" | "interface"; + root: string; + root_url?: string; + theme: string; + title: string; + version: string; + space_id: string | null; + is_colab: boolean; + show_api: boolean; + stylesheets: string[]; + path: string; +} + +export interface Payload { + data: Array; + fn_index?: number; + event_data?: unknown; + time?: Date; +} + +export interface PostResponse { + error?: string; + [x: string]: any; +} +export interface UploadResponse { + error?: string; + files?: Array; +} + +export interface Status { + queue: boolean; + code?: string; + success?: boolean; + stage: "pending" | "error" | "complete" | "generating"; + size?: number; + position?: number; + eta?: number; + message?: string; + progress_data?: Array<{ + progress: number | null; + index: number | null; + length: number | null; + unit: string | null; + desc: string | null; + }>; + time?: Date; +} + +export interface SpaceStatusNormal { + status: "sleeping" | "running" | "building" | "error" | "stopped"; + detail: + | "SLEEPING" + | "RUNNING" + | "RUNNING_BUILDING" + | "BUILDING" + | "NOT_FOUND"; + load_status: "pending" | "error" | "complete" | "generating"; + message: string; +} +export interface SpaceStatusError { + status: "space_error" | "paused"; + detail: + | "NO_APP_FILE" + | "CONFIG_ERROR" + | "BUILD_ERROR" + | "RUNTIME_ERROR" + | "PAUSED"; + load_status: "error"; + message: string; + discussions_enabled: boolean; +} +export type SpaceStatus = SpaceStatusNormal | SpaceStatusError; + +export type status_callback_function = (a: Status) => void; +export type SpaceStatusCallback = (a: SpaceStatus) => void; + +export type EventType = "data" | "status"; + +export interface EventMap { + data: Payload; + status: Status; +} + +export type Event = { + [P in K]: EventMap[P] & { type: P; endpoint: string; fn_index: number }; +}[K]; +export type EventListener = (event: Event) => void; +export type ListenerMap = { + [P in K]?: EventListener[]; +}; +export interface FileData { + name: string; + orig_name?: string; + size?: number; + data: string; + blob?: File; + is_file?: boolean; + mime_type?: string; + alt_text?: string; +} diff --git a/src/gradio/utils.ts b/src/gradio/utils.ts new file mode 100644 index 0000000..7340081 --- /dev/null +++ b/src/gradio/utils.ts @@ -0,0 +1,211 @@ +import type { Config } from "./types"; + +export function determine_protocol(endpoint: string): { + ws_protocol: "ws" | "wss"; + http_protocol: "http:" | "https:"; + host: string; +} { + if (endpoint.startsWith("http")) { + const { protocol, host } = new URL(endpoint); + + if (host.endsWith("hf.space")) { + return { + ws_protocol: "wss", + host: host, + http_protocol: protocol as "http:" | "https:", + }; + } else { + return { + ws_protocol: protocol === "https:" ? "wss" : "ws", + http_protocol: protocol as "http:" | "https:", + host, + }; + } + } + + // default to secure if no protocol is provided + return { + ws_protocol: "wss", + http_protocol: "https:", + host: endpoint, + }; +} + +export const RE_SPACE_NAME = /^[^\/]*\/[^\/]*$/; +export const RE_SPACE_DOMAIN = /.*hf\.space\/{0,1}$/; +export async function process_endpoint( + app_reference: string, + token?: `hf_${string}` +): Promise<{ + space_id: string | false; + host: string; + ws_protocol: "ws" | "wss"; + http_protocol: "http:" | "https:"; +}> { + const headers: { Authorization?: string } = {}; + if (token) { + headers.Authorization = `Bearer ${token}`; + } + + const _app_reference = app_reference.trim(); + + if (RE_SPACE_NAME.test(_app_reference)) { + try { + const res = await fetch( + `https://huggingface.co/api/spaces/${_app_reference}/host`, + { headers } + ); + + if (res.status !== 200) + throw new Error("Space metadata could not be loaded."); + const _host = (await res.json()).host; + + return { + space_id: app_reference, + ...determine_protocol(_host), + }; + } catch (e: any) { + throw new Error("Space metadata could not be loaded." + e.message); + } + } + + if (RE_SPACE_DOMAIN.test(_app_reference)) { + const { ws_protocol, http_protocol, host } = + determine_protocol(_app_reference); + + return { + space_id: host.replace(".hf.space", ""), + ws_protocol, + http_protocol, + host, + }; + } + + return { + space_id: false, + ...determine_protocol(_app_reference), + }; +} + +export function map_names_to_ids(fns: Config["dependencies"]) { + let apis: Record = {}; + + fns.forEach(({ api_name }, i) => { + if (api_name) apis[api_name] = i; + }); + + return apis; +} + +const RE_DISABLED_DISCUSSION = + /^(?=[^]*\b[dD]iscussions{0,1}\b)(?=[^]*\b[dD]isabled\b)[^]*$/; +export async function discussions_enabled(space_id: string) { + try { + const r = await fetch( + `https://huggingface.co/api/spaces/${space_id}/discussions`, + { + method: "HEAD", + } + ); + const error = r.headers.get("x-error-message"); + + if (error && RE_DISABLED_DISCUSSION.test(error)) return false; + else return true; + } catch (e) { + return false; + } +} + +export async function get_space_hardware( + space_id: string, + token: `hf_${string}` +) { + const headers: { Authorization?: string } = {}; + if (token) { + headers.Authorization = `Bearer ${token}`; + } + + try { + const res = await fetch( + `https://huggingface.co/api/spaces/${space_id}/runtime`, + { headers } + ); + + if (res.status !== 200) + throw new Error("Space hardware could not be obtained."); + + const { hardware } = await res.json(); + + return hardware; + } catch (e: any) { + throw new Error(e.message); + } +} + +export async function set_space_hardware( + space_id: string, + new_hardware: (typeof hardware_types)[number], + token: `hf_${string}` +) { + const headers: { Authorization?: string } = {}; + if (token) { + headers.Authorization = `Bearer ${token}`; + } + + try { + const res = await fetch( + `https://huggingface.co/api/spaces/${space_id}/hardware`, + { headers, body: JSON.stringify(new_hardware) } + ); + + if (res.status !== 200) + throw new Error( + "Space hardware could not be set. Please ensure the space hardware provided is valid and that a Hugging Face token is passed in." + ); + + const { hardware } = await res.json(); + + return hardware; + } catch (e: any) { + throw new Error(e.message); + } +} + +export async function set_space_timeout( + space_id: string, + timeout: number, + token: `hf_${string}` +) { + const headers: { Authorization?: string } = {}; + if (token) { + headers.Authorization = `Bearer ${token}`; + } + + try { + const res = await fetch( + `https://huggingface.co/api/spaces/${space_id}/hardware`, + { headers, body: JSON.stringify({ seconds: timeout }) } + ); + + if (res.status !== 200) + throw new Error( + "Space hardware could not be set. Please ensure the space hardware provided is valid and that a Hugging Face token is passed in." + ); + + const { hardware } = await res.json(); + + return hardware; + } catch (e: any) { + throw new Error(e.message); + } +} + +export const hardware_types = [ + "cpu-basic", + "cpu-upgrade", + "t4-small", + "t4-medium", + "a10g-small", + "a10g-large", + "a100-large", +] as const; diff --git a/src/index.ts b/src/index.ts index a5f57f8..6da0ecf 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,3 +6,4 @@ export * from "./midjourne.api"; export * from "./command"; export * from "./verify.human"; export * from "./banned.words"; +export * from "./face.swap"; diff --git a/src/interfaces/message.ts b/src/interfaces/message.ts index f4730de..c18638f 100644 --- a/src/interfaces/message.ts +++ b/src/interfaces/message.ts @@ -47,6 +47,14 @@ export interface MJSettings { flags: number; options: MJOptions[]; } +export interface MJDescribe { + id: string; + flags: number; + uri: string; + proxy_url?: string; + options: MJOptions[]; + descriptions: string[]; +} export interface MJShorten { description: string; diff --git a/src/midjourne.api.ts b/src/midjourne.api.ts index a6344bb..80082b6 100644 --- a/src/midjourne.api.ts +++ b/src/midjourne.api.ts @@ -359,7 +359,7 @@ export class MidjourneyApi extends Command { return resp; } - async UploadImageByBole(blob: Blob, filename = "image.png") { + async UploadImageByBole(blob: Blob, filename = nextNonce() + ".png") { const fileData = await blob.arrayBuffer(); const mimeType = blob.type; const file_size = fileData.byteLength; @@ -435,4 +435,30 @@ export class MidjourneyApi extends Command { const payload = await this.describePayload(image, nonce); return this.safeIteractions(payload); } + async upImageApi(image: DiscordImage, nonce?: string) { + const { SalaiToken, DiscordBaseUrl, ChannelId, fetch } = this.config; + const payload = { + content: "", + nonce, + channel_id: ChannelId, + type: 0, + sticker_ids: [], + attachments: [image], + }; + + const url = new URL( + `${DiscordBaseUrl}/api/v9/channels/${ChannelId}/messages` + ); + const headers = { + Authorization: SalaiToken, + "content-type": "application/json", + }; + const response = await fetch(url, { + headers, + method: "POST", + body: JSON.stringify(payload), + }); + + return response.status; + } } diff --git a/src/midjourney.ts b/src/midjourney.ts index 0c080ac..5e9414c 100644 --- a/src/midjourney.ts +++ b/src/midjourney.ts @@ -6,8 +6,15 @@ import { } from "./interfaces"; import { MidjourneyApi } from "./midjourne.api"; import { MidjourneyMessage } from "./discord.message"; -import { toRemixCustom, custom2Type, nextNonce, random } from "./utls"; +import { + toRemixCustom, + custom2Type, + nextNonce, + random, + base64ToBlob, +} from "./utls"; import { WsMessage } from "./discord.ws"; +import { faceSwap } from "./face.swap"; export class Midjourney extends MidjourneyMessage { public config: MJConfig; private wsClient?: WsMessage; @@ -156,6 +163,7 @@ export class Midjourney extends MidjourneyMessage { const wsClient = await this.getWsClient(); const nonce = nextNonce(); const DcImage = await this.MJApi.UploadImageByUri(imgUri); + this.log(`Describe`, DcImage); const httpStatus = await this.MJApi.DescribeApi(DcImage, nonce); if (httpStatus !== 204) { throw new Error(`DescribeApi failed with status ${httpStatus}`); @@ -361,6 +369,24 @@ export class Midjourney extends MidjourneyMessage { }); } + async FaceSwap(target: string, source: string) { + const wsClient = await this.getWsClient(); + const app = new faceSwap(this.config.HuggingFaceToken); + const Target = await (await this.config.fetch(target)).blob(); + const Source = await (await this.config.fetch(source)).blob(); + const res = await app.changeFace(Target, Source); + this.log(res[0]); + const blob = await base64ToBlob(res[0] as string); + const DcImage = await this.MJApi.UploadImageByBole(blob); + const nonce = nextNonce(); + const httpStatus = await this.MJApi.DescribeApi(DcImage, nonce); + if (httpStatus !== 204) { + throw new Error(`DescribeApi failed with status ${httpStatus}`); + } + const describe = await wsClient.waitDescribe(nonce); + return describe?.uri; + } + Close() { if (this.wsClient) { this.wsClient.close(); diff --git a/src/utls/index.ts b/src/utls/index.ts index c61f4ff..f912e4f 100644 --- a/src/utls/index.ts +++ b/src/utls/index.ts @@ -152,3 +152,24 @@ export const toRemixCustom = (customID: string) => { const convertedString = `MJ::RemixModal::${parts[4]}::${parts[3]}::1`; return convertedString; }; + +export async function base64ToBlob(base64Image: string): Promise { + // 移除 base64 图像头部信息 + const base64Data = base64Image.replace( + /^data:image\/(png|jpeg|jpg);base64,/, + "" + ); + + // 将 base64 数据解码为二进制数据 + const binaryData = atob(base64Data); + + // 创建一个 Uint8Array 来存储二进制数据 + const arrayBuffer = new ArrayBuffer(binaryData.length); + const uint8Array = new Uint8Array(arrayBuffer); + for (let i = 0; i < binaryData.length; i++) { + uint8Array[i] = binaryData.charCodeAt(i); + } + + // 使用 Uint8Array 创建 Blob 对象 + return new Blob([uint8Array], { type: "image/png" }); // 替换为相应的 MIME 类型 +} diff --git a/test/face.ts b/test/face.ts new file mode 100644 index 0000000..3e56998 --- /dev/null +++ b/test/face.ts @@ -0,0 +1,25 @@ +import "dotenv/config"; +import { faceSwap } from "../src"; +/** + * + * ``` + * npx tsx test/face.ts + * ``` + */ + +async function test2() { + const app = new faceSwap(process.env.HuggingFaceToken); + const Target = await ( + await fetch( + "https://cdn.discordapp.com/attachments/1108587422389899304/1129321837042602016/guapitu006_a_girls_face_with_david_bowies_thunderbolt_71ee5899-bd45-4fc4-8c9d-92f19ddb0a03.png" + ) + ).blob(); + const Source = await ( + await fetch( + "https://cdn.discordapp.com/attachments/1108587422389899304/1129321826804306031/guapitu006_Cute_warrior_girl_in_the_style_of_Baten_Kaitos__111f39bc-329e-4fab-9af7-ee219fedf260.png" + ) + ).blob(); + + await app.changeFace(Target, Source); +} +test2(); diff --git a/yarn.lock b/yarn.lock index d3ba752..ed5462e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -296,6 +296,11 @@ prettier@^2.8.8: resolved "https://registry.npmjs.org/prettier/-/prettier-2.8.8.tgz" integrity sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q== +semiver@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/semiver/-/semiver-1.1.0.tgz#9c97fb02c21c7ce4fcf1b73e2c7a24324bdddd5f" + integrity sha512-QNI2ChmuioGC1/xjyYwyZYADILWyW6AmS1UH6gDj/SFUUUS4MBAWs/7mxnkRPc/F4iHezDP+O8t0dO8WHiEOdg== + snowyflake@^2.0.0: version "2.0.0" resolved "https://registry.npmjs.org/snowyflake/-/snowyflake-2.0.0.tgz"