diff --git a/.github/workflows/releases-4.2.yaml b/.github/workflows/v4.3.x.yaml
similarity index 93%
rename from .github/workflows/releases-4.2.yaml
rename to .github/workflows/v4.3.x.yaml
index 2431cb3..a658e07 100644
--- a/.github/workflows/releases-4.2.yaml
+++ b/.github/workflows/v4.3.x.yaml
@@ -1,6 +1,6 @@
-name: Package
+name: verson 4.3.x
env:
- APPVERSION: 4.2.${{ github.run_number }}
+ APPVERSION: 4.3.${{ github.run_number }}
on:
workflow_dispatch:
push:
diff --git a/README.md b/README.md
index 69df071..f525cb4 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# midjourney-api
-Node.js client for the unofficial MidJourney api.
-Major update New [niji bot](https://github.com/erictik/midjourney-api/blob/main/example/imagine-niji.ts) & [custom zoom](https://github.com/erictik/midjourney-api/blob/main/example/customzoom.ts) & [remix mode](https://github.com/erictik/midjourney-api/blob/main/example/variation-ws.ts)
+Node.js client for the unofficial MidJourney api.
+
@@ -9,6 +9,11 @@ Major update New [niji bot](https://github.com/erictik/midjourney-api/blob/main/
+## What's new
+
+- [face swap](https://github.com/erictik/midjourney-api/blob/main/example/faceswap.ts)
+- Major update New [niji bot](https://github.com/erictik/midjourney-api/blob/main/example/imagine-niji.ts) & [custom zoom](https://github.com/erictik/midjourney-api/blob/main/example/customzoom.ts) & [remix mode](https://github.com/erictik/midjourney-api/blob/main/example/variation-ws.ts)
+
## Install
```bash
@@ -21,40 +26,42 @@ yarn add midjourney
```typescript
import { Midjourney } from "midjourney";
- const client = new Midjourney({
- ServerId: process.env.SERVER_ID,
- ChannelId: process.env.CHANNEL_ID,
- SalaiToken: process.env.SALAI_TOKEN,
- Debug: true,
- Ws:true,
- });
- await client.Connect();
- const Imagine = await client.Imagine("A little pink elephant", (uri: string, progress:string) => {
- console.log("Imagine", uri, "progress", progress);
- });
- console.log({ Imagine });
-
- const Variation = await client.Variation({
- index: 2,
- msgId: Imagine.id,
- hash: Imagine.hash,
- flags: Imagine.flags,
- loading: (uri: string, progress: string) => {
- console.log("Variation.loading", uri, "progress", progress);
- },
- });
- console.log({ Variation });
- const Upscale = await client.Upscale({
- index: 2,
- msgId: Variation.id,
- hash: Variation.hash,
- flags: Variation.flags,
- loading: (uri: string, progress: string) => {
- console.log("Upscale.loading", uri, "progress", progress);
- },
- });
- console.log({ Upscale });
-
+const client = new Midjourney({
+ ServerId: process.env.SERVER_ID,
+ ChannelId: process.env.CHANNEL_ID,
+ SalaiToken: process.env.SALAI_TOKEN,
+ Debug: true,
+ Ws: true,
+});
+await client.Connect();
+const Imagine = await client.Imagine(
+ "A little pink elephant",
+ (uri: string, progress: string) => {
+ console.log("Imagine", uri, "progress", progress);
+ }
+);
+console.log({ Imagine });
+
+const Variation = await client.Variation({
+ index: 2,
+ msgId: Imagine.id,
+ hash: Imagine.hash,
+ flags: Imagine.flags,
+ loading: (uri: string, progress: string) => {
+ console.log("Variation.loading", uri, "progress", progress);
+ },
+});
+console.log({ Variation });
+const Upscale = await client.Upscale({
+ index: 2,
+ msgId: Variation.id,
+ hash: Variation.hash,
+ flags: Variation.flags,
+ loading: (uri: string, progress: string) => {
+ console.log("Upscale.loading", uri, "progress", progress);
+ },
+});
+console.log({ Upscale });
```
## Example
@@ -77,12 +84,12 @@ npm install
```
3. set the environment variables
- - [How to get your Discord TOKEN:](https://www.androidauthority.com/get-discord-token-3149920/)
- - [Create a server](https://discord.com/blog/starting-your-first-discord-server) and [Invite Midjourney Bot to Your Server](https://docs.midjourney.com/docs/invite-the-bot)
- OR [Join my discord server](https://discord.com/invite/GavuGHQbV4)
- - How to get server and channel ids:
- when you click on a channel in your server in the browser expect to have the follow URL pattern `https://discord.com/channels/$SERVER_ID/$CHANNEL_ID`
+- [How to get your Discord TOKEN:](https://www.androidauthority.com/get-discord-token-3149920/)
+- [Create a server](https://discord.com/blog/starting-your-first-discord-server) and [Invite Midjourney Bot to Your Server](https://docs.midjourney.com/docs/invite-the-bot)
+ OR [Join my discord server](https://discord.com/invite/GavuGHQbV4)
+- How to get server and channel ids:
+ when you click on a channel in your server in the browser expect to have the follow URL pattern `https://discord.com/channels/$SERVER_ID/$CHANNEL_ID`
```bash
#example variables, please set up yours
@@ -99,12 +106,13 @@ npx tsx example/imagine-ws.ts
```
## route-map
+
- [x] `/imagine` `variation` `upscale` `reroll` `blend` `zoomout` `vary`
- [x] `/info`
- [x] `/fast ` and `/relax `
-- [x] [`/prefer remix`](https://github.com/erictik/midjourney-api/blob/main/example/prefer-remix.ts)
+- [x] [`/prefer remix`](https://github.com/erictik/midjourney-api/blob/main/example/prefer-remix.ts)
- [x] [`variation (remix mode)`](https://github.com/erictik/midjourney-api/blob/main/example/variation-ws.ts)
-- [x] `/describe`
+- [x] `/describe`
- [x] [`/shorten`](https://github.com/erictik/midjourney-api/blob/main/example/shorten.ts)
- [x] `/settings` `reset`
- [x] verify human
@@ -112,13 +120,22 @@ npx tsx example/imagine-ws.ts
- [x] [niji bot](https://github.com/erictik/midjourney-api/blob/main/example/imagine-niji.ts)
- [x] [custom zoom](https://github.com/erictik/midjourney-api/blob/main/example/customzoom.ts)
- [x] autoload command payload
+
---
+
## Projects
+
- [midjourney-ui](https://github.com/erictik/midjourney-ui/)
- [midjourney-discord](https://github.com/erictik/midjourney-discord)-bot
- [phrame](https://github.com/jakowenko/phrame)
- [guapitu](https://www.guapitu.com/zh/draw?code=RRXQNF)
+
---
+
+## Support
+
+
## Star History
+
[![Star History Chart](https://api.star-history.com/svg?repos=erictik/midjourney-api&type=Date)](https://star-history.com/#erictik/midjourney-api&Date)
diff --git a/example/faceswap.ts b/example/faceswap.ts
new file mode 100644
index 0000000..309df18
--- /dev/null
+++ b/example/faceswap.ts
@@ -0,0 +1,33 @@
+import "dotenv/config";
+import { Midjourney, detectBannedWords } from "../src";
+/**
+ *
+ * a simple example of how to use faceSwap
+ * ```
+ * npx tsx example/faceswap.ts
+ * ```
+ */
+async function main() {
+ const source = `https://cdn.discordapp.com/attachments/1107965981839605792/1129362418775113789/3829c5d7-3e7e-473c-9c7b-b858e3ec97bc.jpeg`;
+ // const source = `https://cdn.discordapp.com/attachments/1108587422389899304/1129321826804306031/guapitu006_Cute_warrior_girl_in_the_style_of_Baten_Kaitos__111f39bc-329e-4fab-9af7-ee219fedf260.png`;
+ const target = `https://cdn.discordapp.com/attachments/1108587422389899304/1129321837042602016/guapitu006_a_girls_face_with_david_bowies_thunderbolt_71ee5899-bd45-4fc4-8c9d-92f19ddb0a03.png`;
+ const client = new Midjourney({
+ ServerId: process.env.SERVER_ID,
+ ChannelId: process.env.CHANNEL_ID,
+ SalaiToken: process.env.SALAI_TOKEN,
+ Debug: true,
+ HuggingFaceToken: process.env.HUGGINGFACE_TOKEN,
+ });
+
+ const info = await client.FaceSwap(target, source);
+ console.log(info);
+}
+main()
+ .then(() => {
+ console.log("finished");
+ process.exit(0);
+ })
+ .catch((err) => {
+ console.error(err);
+ process.exit(1);
+ });
diff --git a/package.json b/package.json
index 8997014..f915e85 100644
--- a/package.json
+++ b/package.json
@@ -47,6 +47,7 @@
"@huggingface/inference": "^2.5.0",
"async": "^3.2.4",
"isomorphic-ws": "^5.0.0",
+ "semiver": "^1.1.0",
"snowyflake": "^2.0.0",
"tslib": "^2.5.0",
"ws": "^8.13.0"
diff --git a/src/discord.ws.ts b/src/discord.ws.ts
index 04be61a..fec1c16 100644
--- a/src/discord.ws.ts
+++ b/src/discord.ws.ts
@@ -9,6 +9,7 @@ import {
MJOptions,
OnModal,
MJShorten,
+ MJDescribe,
} from "./interfaces";
import { MidjourneyApi } from "./midjourne.api";
import {
@@ -201,10 +202,16 @@ export class WsMessage {
this.emit("settings", message);
return;
case "describe":
- this.emitMJ(id, {
+ // console.log("describe", "meseesage", message);
+ const describe: MJDescribe = {
+ id: id,
+ flags: message.flags,
descriptions: embeds?.[0]?.description.split("\n\n"),
+ uri: embeds?.[0]?.image?.url,
+ proxy_url: embeds?.[0]?.image?.proxy_url,
options: formatOptions(components),
- });
+ };
+ this.emitMJ(id, describe);
break;
case "prefer remix":
if (content != "") {
@@ -606,10 +613,7 @@ export class WsMessage {
});
}
async waitDescribe(nonce: string) {
- return new Promise<{
- options: MJOptions[];
- descriptions: string[];
- } | null>((resolve) => {
+ return new Promise((resolve) => {
this.onceMJ(nonce, (message) => {
resolve(message);
});
diff --git a/src/face.swap.ts b/src/face.swap.ts
new file mode 100644
index 0000000..2b99709
--- /dev/null
+++ b/src/face.swap.ts
@@ -0,0 +1,24 @@
+import { client } from "./gradio/index";
+
+export class faceSwap {
+ public hf_token?: string;
+ constructor(hf_token?: string) {
+ this.hf_token = hf_token;
+ }
+ async changeFace(Target: Blob, Source: Blob) {
+ const app = await client("https://felixrosberg-face-swap.hf.space/", {
+ hf_token: this.hf_token as any,
+ });
+ // console.log("app", app);
+ const result: any = await app.predict(1, [
+ Target, // blob in 'Target' Image component
+ Source, // blob in 'Source' Image component
+ 0, // number (numeric value between 0 and 100) in 'Anonymization ratio (%)' Slider component
+ 0, // number (numeric value between 0 and 100) in 'Adversarial defense ratio (%)' Slider component
+ "Compare", // string[] (array of strings) in 'Mode' Checkboxgroup component
+ ]);
+ // result.data;
+ return result.data;
+ // console.log(result.data[0]);
+ }
+}
diff --git a/src/gradio/client.ts b/src/gradio/client.ts
new file mode 100644
index 0000000..03f4a57
--- /dev/null
+++ b/src/gradio/client.ts
@@ -0,0 +1,1336 @@
+import semiver from "semiver";
+
+import {
+ process_endpoint,
+ RE_SPACE_NAME,
+ map_names_to_ids,
+ discussions_enabled,
+ get_space_hardware,
+ set_space_hardware,
+ set_space_timeout,
+ hardware_types,
+} from "./utils";
+
+import type {
+ EventType,
+ EventListener,
+ ListenerMap,
+ Event,
+ Payload,
+ PostResponse,
+ UploadResponse,
+ Status,
+ SpaceStatus,
+ SpaceStatusCallback,
+ FileData,
+} from "./types";
+
+import type { Config } from "./types";
+import WebSocket from "isomorphic-ws";
+
+type event = (
+ eventType: K,
+ listener: EventListener
+) => SubmitReturn;
+type predict = (
+ endpoint: string | number,
+ data?: unknown[],
+ event_data?: unknown
+) => Promise;
+
+type client_return = {
+ predict: predict;
+ config: Config;
+ submit: (
+ endpoint: string | number,
+ data?: unknown[],
+ event_data?: unknown
+ ) => SubmitReturn;
+ view_api: (c?: Config) => Promise>;
+};
+
+type SubmitReturn = {
+ on: event;
+ off: event;
+ cancel: () => Promise;
+ destroy: () => void;
+};
+
+const QUEUE_FULL_MSG = "This application is too busy. Keep trying!";
+const BROKEN_CONNECTION_MSG = "Connection errored out.";
+
+export let NodeBlob: Blob;
+
+export async function duplicate(
+ app_reference: string,
+ options: {
+ hf_token: `hf_${string}`;
+ private?: boolean;
+ status_callback: SpaceStatusCallback;
+ hardware?: (typeof hardware_types)[number];
+ timeout?: number;
+ }
+) {
+ const { hf_token, private: _private, hardware, timeout } = options;
+
+ if (hardware && !hardware_types.includes(hardware)) {
+ throw new Error(
+ `Invalid hardware type provided. Valid types are: ${hardware_types
+ .map((v) => `"${v}"`)
+ .join(",")}.`
+ );
+ }
+ const headers = {
+ Authorization: `Bearer ${hf_token}`,
+ };
+
+ const user = (
+ await (
+ await fetch(`https://huggingface.co/api/whoami-v2`, {
+ headers,
+ })
+ ).json()
+ ).name;
+
+ const space_name = app_reference.split("/")[1];
+ const body: {
+ repository: string;
+ private?: boolean;
+ } = {
+ repository: `${user}/${space_name}`,
+ };
+
+ if (_private) {
+ body.private = true;
+ }
+
+ try {
+ const response = await fetch(
+ `https://huggingface.co/api/spaces/${app_reference}/duplicate`,
+ {
+ method: "POST",
+ headers: { "Content-Type": "application/json", ...headers },
+ body: JSON.stringify(body),
+ }
+ );
+
+ if (response.status === 409) {
+ return client(`${user}/${space_name}`, options);
+ } else {
+ const duplicated_space = await response.json();
+
+ let original_hardware;
+
+ if (!hardware) {
+ original_hardware = await get_space_hardware(app_reference, hf_token);
+ }
+
+ const requested_hardware = hardware || original_hardware || "cpu-basic";
+ await set_space_hardware(
+ `${user}/${space_name}`,
+ requested_hardware,
+ hf_token
+ );
+
+ await set_space_timeout(
+ `${user}/${space_name}`,
+ timeout || 300,
+ hf_token
+ );
+ return client(duplicated_space.url, options);
+ }
+ } catch (e: any) {
+ throw new Error(e);
+ }
+}
+
+/**
+ * We need to inject a customized fetch implementation for the Wasm version.
+ */
+export function api_factory(fetch_implementation: typeof fetch) {
+ return { post_data, upload_files, client, handle_blob };
+
+ async function post_data(
+ url: string,
+ body: unknown,
+ token?: `hf_${string}`
+ ): Promise<[PostResponse, number]> {
+ const headers: {
+ Authorization?: string;
+ "Content-Type": "application/json";
+ } = { "Content-Type": "application/json" };
+ if (token) {
+ headers.Authorization = `Bearer ${token}`;
+ }
+ try {
+ var response = await fetch_implementation(url, {
+ method: "POST",
+ body: JSON.stringify(body),
+ headers,
+ });
+ } catch (e) {
+ return [{ error: BROKEN_CONNECTION_MSG }, 500];
+ }
+ const output: PostResponse = await response.json();
+ return [output, response.status];
+ }
+
+ async function upload_files(
+ root: string,
+ files: Array,
+ token?: `hf_${string}`
+ ): Promise {
+ const headers: {
+ Authorization?: string;
+ } = {};
+ if (token) {
+ headers.Authorization = `Bearer ${token}`;
+ }
+
+ const formData = new FormData();
+ files.forEach((file) => {
+ formData.append("files", file);
+ });
+ try {
+ var response = await fetch_implementation(`${root}/upload`, {
+ method: "POST",
+ body: formData,
+ headers,
+ });
+ } catch (e) {
+ return { error: BROKEN_CONNECTION_MSG };
+ }
+ const output: UploadResponse["files"] = await response.json();
+ return { files: output };
+ }
+
+ async function client(
+ app_reference: string,
+ options: {
+ hf_token?: `hf_${string}`;
+ status_callback?: SpaceStatusCallback;
+ normalise_files?: boolean;
+ } = { normalise_files: true }
+ ): Promise {
+ return new Promise(async (res) => {
+ const { status_callback, hf_token, normalise_files } = options;
+ const return_obj = {
+ predict,
+ submit,
+ view_api,
+ // duplicate
+ };
+
+ const transform_files = normalise_files ?? true;
+ // if (typeof window === "undefined" || !("WebSocket" in window)) {
+ // const ws = await import("ws");
+ // NodeBlob = (await import("node:buffer")).Blob;
+ // //@ts-ignore
+ // global.WebSocket = ws.WebSocket;
+ // }
+
+ const { ws_protocol, http_protocol, host, space_id } =
+ await process_endpoint(app_reference, hf_token);
+
+ const session_hash = Math.random().toString(36).substring(2);
+ const last_status: Record = {};
+ let config: Config;
+ let api_map: Record = {};
+
+ let jwt: false | string = false;
+
+ if (hf_token && space_id) {
+ jwt = await get_jwt(space_id, hf_token);
+ }
+
+ async function config_success(_config: Config) {
+ config = _config;
+ api_map = map_names_to_ids(_config?.dependencies || []);
+ try {
+ api = await view_api(config);
+ } catch (e: any) {
+ console.error(`Could not get api details: ${e.message}`);
+ }
+
+ return {
+ config,
+ ...return_obj,
+ };
+ }
+ let api: ApiInfo;
+ async function handle_space_sucess(status: SpaceStatus) {
+ if (status_callback) status_callback(status);
+ if (status.status === "running")
+ try {
+ config = await resolve_config(
+ fetch_implementation,
+ `${http_protocol}//${host}`,
+ hf_token
+ );
+
+ const _config: any = await config_success(config);
+ res(_config);
+ } catch (e) {
+ console.error(e);
+ if (status_callback) {
+ status_callback({
+ status: "error",
+ message: "Could not load this space.",
+ load_status: "error",
+ detail: "NOT_FOUND",
+ });
+ }
+ }
+ }
+
+ try {
+ config = await resolve_config(
+ fetch_implementation,
+ `${http_protocol}//${host}`,
+ hf_token
+ );
+
+ const _config: any = await config_success(config);
+ res(_config);
+ } catch (e) {
+ console.error(e);
+ if (space_id) {
+ check_space_status(
+ space_id,
+ RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain",
+ handle_space_sucess
+ );
+ } else {
+ if (status_callback)
+ status_callback({
+ status: "error",
+ message: "Could not load this space.",
+ load_status: "error",
+ detail: "NOT_FOUND",
+ });
+ }
+ }
+
+ /**
+ * Run a prediction.
+ * @param endpoint - The prediction endpoint to use.
+ * @param status_callback - A function that is called with the current status of the prediction immediately and every time it updates.
+ * @return Returns the data for the prediction or an error message.
+ */
+ function predict(
+ endpoint: string,
+ data: unknown[],
+ event_data?: unknown
+ ) {
+ let data_returned = false;
+ let status_complete = false;
+ return new Promise((res, rej) => {
+ const app = submit(endpoint, data, event_data);
+
+ app
+ .on("data", (d) => {
+ data_returned = true;
+ if (status_complete) {
+ app.destroy();
+ }
+ res(d);
+ })
+ .on("status", (status) => {
+ if (status.stage === "error") rej(status);
+ if (status.stage === "complete" && data_returned) {
+ app.destroy();
+ }
+ if (status.stage === "complete") {
+ status_complete = true;
+ }
+ });
+ });
+ }
+
+ function submit(
+ endpoint: string | number,
+ data: unknown[],
+ event_data?: unknown
+ ): SubmitReturn {
+ let fn_index: number;
+ let api_info: EndpointInfo;
+
+ if (typeof endpoint === "number") {
+ fn_index = endpoint;
+ api_info = api.unnamed_endpoints[fn_index];
+ } else {
+ const trimmed_endpoint = endpoint.replace(/^\//, "");
+
+ fn_index = api_map[trimmed_endpoint];
+ api_info = api.named_endpoints[endpoint.trim()];
+ }
+
+ if (typeof fn_index !== "number") {
+ throw new Error(
+ "There is no endpoint matching that name of fn_index matching that number."
+ );
+ }
+
+ let websocket: WebSocket;
+
+ const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
+ let payload: Payload;
+ let complete: false | Record = false;
+ const listener_map: ListenerMap = {};
+
+ handle_blob(
+ `${http_protocol}//${host + config.path}`,
+ data,
+ api_info,
+ hf_token
+ ).then((_payload) => {
+ payload = { data: _payload || [], event_data, fn_index };
+ if (skip_queue(fn_index, config)) {
+ fire_event({
+ type: "status",
+ endpoint: _endpoint,
+ stage: "pending",
+ queue: false,
+ fn_index,
+ time: new Date(),
+ });
+
+ post_data(
+ `${http_protocol}//${host + config.path}/run${
+ _endpoint.startsWith("/") ? _endpoint : `/${_endpoint}`
+ }`,
+ {
+ ...payload,
+ session_hash,
+ },
+ hf_token
+ )
+ .then(([output, status_code]) => {
+ const data = transform_files
+ ? transform_output(
+ output.data,
+ api_info,
+ config.root,
+ config.root_url
+ )
+ : output.data;
+ if (status_code == 200) {
+ fire_event({
+ type: "data",
+ endpoint: _endpoint,
+ fn_index,
+ data: data,
+ time: new Date(),
+ });
+
+ fire_event({
+ type: "status",
+ endpoint: _endpoint,
+ fn_index,
+ stage: "complete",
+ eta: output.average_duration,
+ queue: false,
+ time: new Date(),
+ });
+ } else {
+ fire_event({
+ type: "status",
+ stage: "error",
+ endpoint: _endpoint,
+ fn_index,
+ message: output.error,
+ queue: false,
+ time: new Date(),
+ });
+ }
+ })
+ .catch((e) => {
+ fire_event({
+ type: "status",
+ stage: "error",
+ message: e.message,
+ endpoint: _endpoint,
+ fn_index,
+ queue: false,
+ time: new Date(),
+ });
+ });
+ } else {
+ fire_event({
+ type: "status",
+ stage: "pending",
+ queue: true,
+ endpoint: _endpoint,
+ fn_index,
+ time: new Date(),
+ });
+
+ let url = new URL(`${ws_protocol}://${host}${config.path}
+ /queue/join`);
+
+ if (jwt) {
+ url.searchParams.set("__sign", jwt);
+ }
+
+ websocket = new WebSocket(url);
+
+ websocket.onclose = (evt) => {
+ if (!evt.wasClean) {
+ fire_event({
+ type: "status",
+ stage: "error",
+ message: BROKEN_CONNECTION_MSG,
+ queue: true,
+ endpoint: _endpoint,
+ fn_index,
+ time: new Date(),
+ });
+ }
+ };
+
+ websocket.onmessage = function (event) {
+ const _data = JSON.parse(event.data.toString());
+ const { type, status, data } = handle_message(
+ _data,
+ last_status[fn_index]
+ );
+
+ if (type === "update" && status && !complete) {
+ // call 'status' listeners
+ fire_event({
+ type: "status",
+ endpoint: _endpoint,
+ fn_index,
+ time: new Date(),
+ ...status,
+ });
+ if (status.stage === "error") {
+ websocket.close();
+ }
+ } else if (type === "hash") {
+ websocket.send(JSON.stringify({ fn_index, session_hash }));
+ return;
+ } else if (type === "data") {
+ websocket.send(JSON.stringify({ ...payload, session_hash }));
+ } else if (type === "complete") {
+ complete = status || false;
+ } else if (type === "generating") {
+ fire_event({
+ type: "status",
+ time: new Date(),
+ ...status,
+ stage: status?.stage!,
+ queue: true,
+ endpoint: _endpoint,
+ fn_index,
+ });
+ }
+ if (data) {
+ fire_event({
+ type: "data",
+ time: new Date(),
+ data: transform_files
+ ? transform_output(
+ data.data,
+ api_info,
+ config.root,
+ config.root_url
+ )
+ : data.data,
+ endpoint: _endpoint,
+ fn_index,
+ });
+
+ if (complete) {
+ fire_event({
+ type: "status",
+ time: new Date(),
+ ...complete,
+ stage: status?.stage!,
+ queue: true,
+ endpoint: _endpoint,
+ fn_index,
+ });
+ websocket.close();
+ }
+ }
+ };
+
+ // different ws contract for gradio versions older than 3.6.0
+ //@ts-ignore
+ if (semiver(config.version || "2.0.0", "3.6") < 0) {
+ addEventListener("open", () =>
+ websocket.send(JSON.stringify({ hash: session_hash }))
+ );
+ }
+ }
+ });
+
+ function fire_event(event: Event) {
+ const narrowed_listener_map: ListenerMap = listener_map;
+ const listeners = narrowed_listener_map[event.type] || [];
+ listeners?.forEach((l) => l(event));
+ }
+
+ function on(
+ eventType: K,
+ listener: EventListener
+ ) {
+ const narrowed_listener_map: ListenerMap = listener_map;
+ const listeners = narrowed_listener_map[eventType] || [];
+ narrowed_listener_map[eventType] = listeners;
+ listeners?.push(listener);
+
+ return { on, off, cancel, destroy };
+ }
+
+ function off(
+ eventType: K,
+ listener: EventListener
+ ) {
+ const narrowed_listener_map: ListenerMap = listener_map;
+ let listeners = narrowed_listener_map[eventType] || [];
+ listeners = listeners?.filter((l) => l !== listener);
+ narrowed_listener_map[eventType] = listeners;
+
+ return { on, off, cancel, destroy };
+ }
+
+ async function cancel() {
+ const _status: Status = {
+ stage: "complete",
+ queue: false,
+ time: new Date(),
+ };
+ complete = _status;
+ fire_event({
+ ..._status,
+ type: "status",
+ endpoint: _endpoint,
+ fn_index: fn_index,
+ });
+
+ if (websocket && websocket.readyState === 0) {
+ websocket.addEventListener("open", () => {
+ websocket.close();
+ });
+ } else {
+ websocket.close();
+ }
+
+ try {
+ await fetch_implementation(
+ `${http_protocol}//${host + config.path}/reset`,
+ {
+ headers: { "Content-Type": "application/json" },
+ method: "POST",
+ body: JSON.stringify({ fn_index, session_hash }),
+ }
+ );
+ } catch (e) {
+ console.warn(
+ "The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
+ );
+ }
+ }
+
+ function destroy() {
+ for (const event_type in listener_map) {
+ listener_map[event_type as "data" | "status"]?.forEach((fn) => {
+ off(event_type as "data" | "status", fn);
+ });
+ }
+ }
+
+ return {
+ on,
+ off,
+ cancel,
+ destroy,
+ };
+ }
+
+ async function view_api(config?: Config): Promise> {
+ if (api) return api;
+
+ const headers: {
+ Authorization?: string;
+ "Content-Type": "application/json";
+ } = { "Content-Type": "application/json" };
+ if (hf_token) {
+ headers.Authorization = `Bearer ${hf_token}`;
+ }
+ let response: Response;
+ // @ts-ignore
+ if (semiver(config.version || "2.0.0", "3.30") < 0) {
+ response = await fetch_implementation(
+ "https://gradio-space-api-fetcher-v2.hf.space/api",
+ {
+ method: "POST",
+ body: JSON.stringify({
+ serialize: false,
+ config: JSON.stringify(config),
+ }),
+ headers,
+ }
+ );
+ } else {
+ response = await fetch_implementation(`${config?.root}/info`, {
+ headers,
+ });
+ }
+
+ if (!response.ok) {
+ throw new Error(BROKEN_CONNECTION_MSG);
+ }
+
+ let api_info = (await response.json()) as
+ | ApiInfo
+ | { api: ApiInfo };
+ if ("api" in api_info) {
+ api_info = api_info.api;
+ }
+
+ if (
+ api_info.named_endpoints["/predict"] &&
+ !api_info.unnamed_endpoints["0"]
+ ) {
+ api_info.unnamed_endpoints[0] = api_info.named_endpoints["/predict"];
+ }
+
+ const x = transform_api_info(api_info, config as Config, api_map);
+ return x;
+ }
+ });
+ }
+
+ async function handle_blob(
+ endpoint: string,
+ data: unknown[],
+ api_info: any,
+ token?: `hf_${string}`
+ ): Promise {
+ const blob_refs = await walk_and_store_blobs(
+ data,
+ undefined,
+ [],
+ true,
+ api_info
+ );
+
+ return Promise.all(
+ blob_refs.map(
+ async ({
+ path,
+ blob,
+ data,
+ type,
+ }: {
+ path: string;
+ blob: Blob;
+ data: string;
+ type: string;
+ }) => {
+ if (blob) {
+ const file_url = (await upload_files(endpoint, [blob], token))
+ ?.files?.[0];
+ return { path, file_url, type };
+ } else {
+ return { path, base64: data, type };
+ }
+ }
+ )
+ ).then((r) => {
+ r.forEach(({ path, file_url, base64, type }) => {
+ if (base64) {
+ update_object(data, base64, path);
+ } else if (type === "Gallery") {
+ update_object(data, file_url, path);
+ } else if (file_url) {
+ const o = {
+ is_file: true,
+ name: `${file_url}`,
+ data: null,
+ // orig_name: "file.csv"
+ };
+ update_object(data, o, path);
+ }
+ });
+
+ return data;
+ });
+ }
+}
+
+export const { post_data, upload_files, client, handle_blob } =
+ api_factory(fetch);
+
+function transform_output(
+ data: any[],
+ api_info: any,
+ root_url: string,
+ remote_url: string | null = null
+): unknown[] {
+ return data.map((d, i) => {
+ if (api_info.returns?.[i]?.component === "File") {
+ return normalise_file(d, root_url, remote_url);
+ } else if (api_info.returns?.[i]?.component === "Gallery") {
+ return d.map((img: any) => {
+ return Array.isArray(img)
+ ? [normalise_file(img[0], root_url, remote_url), img[1]]
+ : [normalise_file(img, root_url, remote_url), null];
+ });
+ } else if (typeof d === "object" && d.is_file) {
+ return normalise_file(d, root_url, remote_url);
+ } else {
+ return d;
+ }
+ });
+}
+
+function normalise_file(
+ file: Array,
+ root: string,
+ root_url: string | null
+): Array;
+function normalise_file(
+ file: FileData | string,
+ root: string,
+ root_url: string | null
+): FileData;
+function normalise_file(
+ file: null,
+ root: string,
+ root_url: string | null
+): null;
+function normalise_file(
+ file: any,
+ root: any,
+ root_url: any
+): Array | FileData | null {
+ if (file == null) return null;
+ if (typeof file === "string") {
+ return {
+ name: "file_data",
+ data: file,
+ };
+ } else if (Array.isArray(file)) {
+ const normalized_file: Array = [];
+
+ for (const x of file) {
+ if (x === null) {
+ normalized_file.push(null);
+ } else {
+ normalized_file.push(normalise_file(x, root, root_url));
+ }
+ }
+
+ return normalized_file as Array;
+ } else if (file.is_file) {
+ if (!root_url) {
+ file.data = root + "/file=" + file.name;
+ } else {
+ file.data = "/proxy=" + root_url + "file=" + file.name;
+ }
+ }
+ return file;
+}
+
+interface ApiData {
+ label: string;
+ type: {
+ type: any;
+ description: string;
+ };
+ component: string;
+ example_input?: any;
+}
+
+interface JsApiData {
+ label: string;
+ type: string;
+ component: string;
+ example_input: any;
+}
+
+interface EndpointInfo {
+ parameters: T[];
+ returns: T[];
+}
+interface ApiInfo {
+ named_endpoints: {
+ [key: string]: EndpointInfo;
+ };
+ unnamed_endpoints: {
+ [key: string]: EndpointInfo;
+ };
+}
+
+function get_type(
+ type: { [key: string]: any },
+ component: string,
+ serializer: string,
+ signature_type: "return" | "parameter"
+) {
+ switch (type.type) {
+ case "string":
+ return "string";
+ case "boolean":
+ return "boolean";
+ case "number":
+ return "number";
+ }
+
+ if (
+ serializer === "JSONSerializable" ||
+ serializer === "StringSerializable"
+ ) {
+ return "any";
+ } else if (serializer === "ListStringSerializable") {
+ return "string[]";
+ } else if (component === "Image") {
+ return signature_type === "parameter" ? "Blob | File | Buffer" : "string";
+ } else if (serializer === "FileSerializable") {
+ if (type?.type === "array") {
+ return signature_type === "parameter"
+ ? "(Blob | File | Buffer)[]"
+ : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}[]`;
+ } else {
+ return signature_type === "parameter"
+ ? "Blob | File | Buffer"
+ : `{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}`;
+ }
+ } else if (serializer === "GallerySerializable") {
+ return signature_type === "parameter"
+ ? "[(Blob | File | Buffer), (string | null)][]"
+ : `[{ name: string; data: string; size?: number; is_file?: boolean; orig_name?: string}, (string | null))][]`;
+ }
+}
+
+function get_description(
+ type: { type: any; description: string },
+ serializer: string
+) {
+ if (serializer === "GallerySerializable") {
+ return "array of [file, label] tuples";
+ } else if (serializer === "ListStringSerializable") {
+ return "array of strings";
+ } else if (serializer === "FileSerializable") {
+ return "array of files or single file";
+ } else {
+ return type.description;
+ }
+}
+
+function transform_api_info(
+ api_info: ApiInfo,
+ config: Config,
+ api_map: Record
+): ApiInfo {
+ let new_data = {
+ named_endpoints: {},
+ unnamed_endpoints: {},
+ };
+ let key: keyof ApiInfo;
+ for (key in api_info) {
+ const cat = api_info[key];
+
+ for (const endpoint in cat) {
+ const dep_index = config.dependencies[endpoint as any]
+ ? endpoint
+ : api_map[endpoint.replace("/", "")];
+
+ const info = cat[endpoint];
+ // @ts-ignore
+ new_data[key][endpoint] = {};
+ // @ts-ignore
+ new_data[key][endpoint].parameters = {};
+ // @ts-ignore
+ new_data[key][endpoint].returns = {};
+ // @ts-ignore
+ new_data[key][endpoint].type =
+ config.dependencies[dep_index as any].types;
+ // @ts-ignore
+ new_data[key][endpoint].parameters = info.parameters.map(
+ // @ts-ignore
+ ({ label, component, type, serializer }) => ({
+ label,
+ component,
+ type: get_type(type, component, serializer, "parameter"),
+ description: get_description(type, serializer),
+ })
+ );
+ // @ts-ignore
+ new_data[key][endpoint].returns = info.returns.map(
+ // @ts-ignore
+ ({ label, component, type, serializer }) => ({
+ label,
+ component,
+ type: get_type(type, component, serializer, "return"),
+ description: get_description(type, serializer),
+ })
+ );
+ }
+ }
+
+ return new_data;
+}
+
+async function get_jwt(
+ space: string,
+ token: `hf_${string}`
+): Promise {
+ try {
+ const r = await fetch(`https://huggingface.co/api/spaces/${space}/jwt`, {
+ headers: {
+ Authorization: `Bearer ${token}`,
+ },
+ });
+
+ const jwt = (await r.json()).token;
+
+ return jwt || false;
+ } catch (e) {
+ console.error(e);
+ return false;
+ }
+}
+
+function update_object(object: any, newValue: any, stack: any) {
+ while (stack.length > 1) {
+ object = object[stack.shift()];
+ }
+
+ object[stack.shift()] = newValue;
+}
+
+export async function walk_and_store_blobs(
+ param: any,
+ type = undefined,
+ path: any[] = [],
+ root = false,
+ api_info: any = undefined
+) {
+ if (Array.isArray(param)) {
+ let blob_refs: any[] = [];
+
+ await Promise.all(
+ param.map(async (v, i) => {
+ let new_path = path.slice();
+ new_path.push(i);
+
+ const array_refs = await walk_and_store_blobs(
+ param[i],
+ root ? api_info?.parameters[i]?.component || undefined : type,
+ new_path,
+ false,
+ api_info
+ );
+
+ blob_refs = blob_refs.concat(array_refs);
+ })
+ );
+
+ return blob_refs;
+ } else if (globalThis.Buffer && param instanceof globalThis.Buffer) {
+ const is_image = type === "Image";
+ return [
+ {
+ path: path,
+ blob: is_image ? false : new NodeBlob([param]),
+ data: is_image ? `${param.toString("base64")}` : false,
+ type,
+ },
+ ];
+ } else if (
+ param instanceof Blob ||
+ (typeof window !== "undefined" && param instanceof File)
+ ) {
+ if (type === "Image") {
+ let data;
+
+ if (typeof window !== "undefined") {
+ // browser
+ data = await image_to_data_uri(param);
+ } else {
+ const buffer = await param.arrayBuffer();
+ data = Buffer.from(buffer).toString("base64");
+ }
+
+ return [{ path, data, type }];
+ } else {
+ return [{ path: path, blob: param, type }];
+ }
+ } else if (typeof param === "object") {
+ let blob_refs: any[] = [];
+ for (let key in param) {
+ if (param.hasOwnProperty(key)) {
+ let new_path: any[] = path.slice();
+ new_path.push(key);
+ blob_refs = blob_refs.concat(
+ await walk_and_store_blobs(
+ param[key],
+ undefined,
+ new_path,
+ false,
+ api_info
+ )
+ );
+ }
+ }
+ return blob_refs;
+ } else {
+ return [];
+ }
+}
+
+function image_to_data_uri(blob: Blob) {
+ return new Promise((resolve, _) => {
+ const reader = new FileReader();
+ reader.onloadend = () => resolve(reader.result);
+ reader.readAsDataURL(blob);
+ });
+}
+
+function skip_queue(id: number, config: Config) {
+ return (
+ !(config?.dependencies?.[id]?.queue === null
+ ? config.enable_queue
+ : config?.dependencies?.[id]?.queue) || false
+ );
+}
+
+async function resolve_config(
+ fetch_implementation: typeof fetch,
+ endpoint?: string,
+ token?: `hf_${string}`
+): Promise {
+ const headers: { Authorization?: string } = {};
+ if (token) {
+ headers.Authorization = `Bearer ${token}`;
+ }
+ if (
+ typeof window !== "undefined" &&
+ window.gradio_config &&
+ location.origin !== "http://localhost:9876"
+ ) {
+ const path = window.gradio_config.root;
+ const config = window.gradio_config;
+ config.root = endpoint + config.root;
+ return { ...config, path: path };
+ } else if (endpoint) {
+ let response = await fetch_implementation(`${endpoint}/config`, {
+ headers,
+ });
+
+ if (response.status === 200) {
+ const config = await response.json();
+ config.path = config.path ?? "";
+ config.root = endpoint;
+ return config;
+ } else {
+ throw new Error("Could not get config.");
+ }
+ }
+
+ throw new Error("No config or app endpoint found");
+}
+
+async function check_space_status(
+ id: string,
+ type: "subdomain" | "space_name",
+ status_callback: SpaceStatusCallback
+) {
+ let endpoint =
+ type === "subdomain"
+ ? `https://huggingface.co/api/spaces/by-subdomain/${id}`
+ : `https://huggingface.co/api/spaces/${id}`;
+ let response;
+ let _status;
+ try {
+ response = await fetch(endpoint);
+ _status = response.status;
+ if (_status !== 200) {
+ throw new Error();
+ }
+ response = await response.json();
+ } catch (e) {
+ status_callback({
+ status: "error",
+ load_status: "error",
+ message: "Could not get space status",
+ detail: "NOT_FOUND",
+ });
+ return;
+ }
+
+ if (!response || _status !== 200) return;
+ const {
+ runtime: { stage },
+ id: space_name,
+ } = response;
+
+ switch (stage) {
+ case "STOPPED":
+ case "SLEEPING":
+ status_callback({
+ status: "sleeping",
+ load_status: "pending",
+ message: "Space is asleep. Waking it up...",
+ detail: stage,
+ });
+
+ setTimeout(() => {
+ check_space_status(id, type, status_callback);
+ }, 1000); // poll for status
+ break;
+ case "PAUSED":
+ status_callback({
+ status: "paused",
+ load_status: "error",
+ message:
+ "This space has been paused by the author. If you would like to try this demo, consider duplicating the space.",
+ detail: stage,
+ discussions_enabled: await discussions_enabled(space_name),
+ });
+ break;
+ case "RUNNING":
+ case "RUNNING_BUILDING":
+ status_callback({
+ status: "running",
+ load_status: "complete",
+ message: "",
+ detail: stage,
+ });
+ // load_config(source);
+ // launch
+ break;
+ case "BUILDING":
+ status_callback({
+ status: "building",
+ load_status: "pending",
+ message: "Space is building...",
+ detail: stage,
+ });
+
+ setTimeout(() => {
+ check_space_status(id, type, status_callback);
+ }, 1000);
+ break;
+ default:
+ status_callback({
+ status: "space_error",
+ load_status: "error",
+ message: "This space is experiencing an issue.",
+ detail: stage,
+ discussions_enabled: await discussions_enabled(space_name),
+ });
+ break;
+ }
+}
+
+function handle_message(
+ data: any,
+ last_status: Status["stage"]
+): {
+ type: "hash" | "data" | "update" | "complete" | "generating" | "none";
+ data?: any;
+ status?: Status;
+} {
+ const queue = true;
+ switch (data.msg) {
+ case "send_data":
+ return { type: "data" };
+ case "send_hash":
+ return { type: "hash" };
+ case "queue_full":
+ return {
+ type: "update",
+ status: {
+ queue,
+ message: QUEUE_FULL_MSG,
+ stage: "error",
+ code: data.code,
+ success: data.success,
+ },
+ };
+ case "estimation":
+ return {
+ type: "update",
+ status: {
+ queue,
+ stage: last_status || "pending",
+ code: data.code,
+ size: data.queue_size,
+ position: data.rank,
+ eta: data.rank_eta,
+ success: data.success,
+ },
+ };
+ case "progress":
+ return {
+ type: "update",
+ status: {
+ queue,
+ stage: "pending",
+ code: data.code,
+ progress_data: data.progress_data,
+ success: data.success,
+ },
+ };
+ case "process_generating":
+ return {
+ type: "generating",
+ status: {
+ queue,
+ message: !data.success ? data.output.error : null,
+ stage: data.success ? "generating" : "error",
+ code: data.code,
+ progress_data: data.progress_data,
+ eta: data.average_duration,
+ },
+ data: data.success ? data.output : null,
+ };
+ case "process_completed":
+ if ("error" in data.output) {
+ return {
+ type: "update",
+ status: {
+ queue,
+ message: data.output.error as string,
+ stage: "error",
+ code: data.code,
+ success: data.success,
+ },
+ };
+ } else {
+ return {
+ type: "complete",
+ status: {
+ queue,
+ message: !data.success ? data.output.error : undefined,
+ stage: data.success ? "complete" : "error",
+ code: data.code,
+ progress_data: data.progress_data,
+ eta: data.output.average_duration,
+ },
+ data: data.success ? data.output : null,
+ };
+ }
+
+ case "process_starts":
+ return {
+ type: "update",
+ status: {
+ queue,
+ stage: "pending",
+ code: data.code,
+ size: data.rank,
+ position: 0,
+ success: data.success,
+ },
+ };
+ }
+
+ return { type: "none", status: { stage: "error", queue } };
+}
diff --git a/src/gradio/globals.d.ts b/src/gradio/globals.d.ts
new file mode 100644
index 0000000..9abbf76
--- /dev/null
+++ b/src/gradio/globals.d.ts
@@ -0,0 +1,31 @@
+declare global {
+ interface Window {
+ __gradio_mode__: "app" | "website";
+ launchGradio: Function;
+ launchGradioFromSpaces: Function;
+ gradio_config: Config;
+ scoped_css_attach: (link: HTMLLinkElement) => void;
+ __is_colab__: boolean;
+ }
+}
+
+export interface Config {
+ auth_required: boolean | undefined;
+ auth_message: string;
+ components: any[];
+ css: string | null;
+ dependencies: any[];
+ dev_mode: boolean;
+ enable_queue: boolean;
+ layout: any;
+ mode: "blocks" | "interface";
+ root: string;
+ theme: string;
+ title: string;
+ version: string;
+ space_id: string | null;
+ is_colab: boolean;
+ show_api: boolean;
+ stylesheets: string[];
+ path: string;
+}
diff --git a/src/gradio/index.ts b/src/gradio/index.ts
new file mode 100644
index 0000000..6061207
--- /dev/null
+++ b/src/gradio/index.ts
@@ -0,0 +1,8 @@
+export {
+ client,
+ post_data,
+ upload_files,
+ duplicate,
+ api_factory,
+} from "./client";
+export type { SpaceStatus } from "./types";
diff --git a/src/gradio/types.ts b/src/gradio/types.ts
new file mode 100644
index 0000000..9e283cf
--- /dev/null
+++ b/src/gradio/types.ts
@@ -0,0 +1,109 @@
+export interface Config {
+ auth_required: boolean | undefined;
+ auth_message: string;
+ components: any[];
+ css: string | null;
+ dependencies: any[];
+ dev_mode: boolean;
+ enable_queue: boolean;
+ layout: any;
+ mode: "blocks" | "interface";
+ root: string;
+ root_url?: string;
+ theme: string;
+ title: string;
+ version: string;
+ space_id: string | null;
+ is_colab: boolean;
+ show_api: boolean;
+ stylesheets: string[];
+ path: string;
+}
+
+export interface Payload {
+ data: Array;
+ fn_index?: number;
+ event_data?: unknown;
+ time?: Date;
+}
+
+export interface PostResponse {
+ error?: string;
+ [x: string]: any;
+}
+export interface UploadResponse {
+ error?: string;
+ files?: Array;
+}
+
+export interface Status {
+ queue: boolean;
+ code?: string;
+ success?: boolean;
+ stage: "pending" | "error" | "complete" | "generating";
+ size?: number;
+ position?: number;
+ eta?: number;
+ message?: string;
+ progress_data?: Array<{
+ progress: number | null;
+ index: number | null;
+ length: number | null;
+ unit: string | null;
+ desc: string | null;
+ }>;
+ time?: Date;
+}
+
+export interface SpaceStatusNormal {
+ status: "sleeping" | "running" | "building" | "error" | "stopped";
+ detail:
+ | "SLEEPING"
+ | "RUNNING"
+ | "RUNNING_BUILDING"
+ | "BUILDING"
+ | "NOT_FOUND";
+ load_status: "pending" | "error" | "complete" | "generating";
+ message: string;
+}
+export interface SpaceStatusError {
+ status: "space_error" | "paused";
+ detail:
+ | "NO_APP_FILE"
+ | "CONFIG_ERROR"
+ | "BUILD_ERROR"
+ | "RUNTIME_ERROR"
+ | "PAUSED";
+ load_status: "error";
+ message: string;
+ discussions_enabled: boolean;
+}
+export type SpaceStatus = SpaceStatusNormal | SpaceStatusError;
+
+export type status_callback_function = (a: Status) => void;
+export type SpaceStatusCallback = (a: SpaceStatus) => void;
+
+export type EventType = "data" | "status";
+
+export interface EventMap {
+ data: Payload;
+ status: Status;
+}
+
+export type Event = {
+ [P in K]: EventMap[P] & { type: P; endpoint: string; fn_index: number };
+}[K];
+export type EventListener = (event: Event) => void;
+export type ListenerMap = {
+ [P in K]?: EventListener[];
+};
+export interface FileData {
+ name: string;
+ orig_name?: string;
+ size?: number;
+ data: string;
+ blob?: File;
+ is_file?: boolean;
+ mime_type?: string;
+ alt_text?: string;
+}
diff --git a/src/gradio/utils.ts b/src/gradio/utils.ts
new file mode 100644
index 0000000..7340081
--- /dev/null
+++ b/src/gradio/utils.ts
@@ -0,0 +1,211 @@
+import type { Config } from "./types";
+
+export function determine_protocol(endpoint: string): {
+ ws_protocol: "ws" | "wss";
+ http_protocol: "http:" | "https:";
+ host: string;
+} {
+ if (endpoint.startsWith("http")) {
+ const { protocol, host } = new URL(endpoint);
+
+ if (host.endsWith("hf.space")) {
+ return {
+ ws_protocol: "wss",
+ host: host,
+ http_protocol: protocol as "http:" | "https:",
+ };
+ } else {
+ return {
+ ws_protocol: protocol === "https:" ? "wss" : "ws",
+ http_protocol: protocol as "http:" | "https:",
+ host,
+ };
+ }
+ }
+
+ // default to secure if no protocol is provided
+ return {
+ ws_protocol: "wss",
+ http_protocol: "https:",
+ host: endpoint,
+ };
+}
+
+export const RE_SPACE_NAME = /^[^\/]*\/[^\/]*$/;
+export const RE_SPACE_DOMAIN = /.*hf\.space\/{0,1}$/;
+export async function process_endpoint(
+ app_reference: string,
+ token?: `hf_${string}`
+): Promise<{
+ space_id: string | false;
+ host: string;
+ ws_protocol: "ws" | "wss";
+ http_protocol: "http:" | "https:";
+}> {
+ const headers: { Authorization?: string } = {};
+ if (token) {
+ headers.Authorization = `Bearer ${token}`;
+ }
+
+ const _app_reference = app_reference.trim();
+
+ if (RE_SPACE_NAME.test(_app_reference)) {
+ try {
+ const res = await fetch(
+ `https://huggingface.co/api/spaces/${_app_reference}/host`,
+ { headers }
+ );
+
+ if (res.status !== 200)
+ throw new Error("Space metadata could not be loaded.");
+ const _host = (await res.json()).host;
+
+ return {
+ space_id: app_reference,
+ ...determine_protocol(_host),
+ };
+ } catch (e: any) {
+ throw new Error("Space metadata could not be loaded." + e.message);
+ }
+ }
+
+ if (RE_SPACE_DOMAIN.test(_app_reference)) {
+ const { ws_protocol, http_protocol, host } =
+ determine_protocol(_app_reference);
+
+ return {
+ space_id: host.replace(".hf.space", ""),
+ ws_protocol,
+ http_protocol,
+ host,
+ };
+ }
+
+ return {
+ space_id: false,
+ ...determine_protocol(_app_reference),
+ };
+}
+
+export function map_names_to_ids(fns: Config["dependencies"]) {
+ let apis: Record = {};
+
+ fns.forEach(({ api_name }, i) => {
+ if (api_name) apis[api_name] = i;
+ });
+
+ return apis;
+}
+
+const RE_DISABLED_DISCUSSION =
+ /^(?=[^]*\b[dD]iscussions{0,1}\b)(?=[^]*\b[dD]isabled\b)[^]*$/;
+export async function discussions_enabled(space_id: string) {
+ try {
+ const r = await fetch(
+ `https://huggingface.co/api/spaces/${space_id}/discussions`,
+ {
+ method: "HEAD",
+ }
+ );
+ const error = r.headers.get("x-error-message");
+
+ if (error && RE_DISABLED_DISCUSSION.test(error)) return false;
+ else return true;
+ } catch (e) {
+ return false;
+ }
+}
+
+export async function get_space_hardware(
+ space_id: string,
+ token: `hf_${string}`
+) {
+ const headers: { Authorization?: string } = {};
+ if (token) {
+ headers.Authorization = `Bearer ${token}`;
+ }
+
+ try {
+ const res = await fetch(
+ `https://huggingface.co/api/spaces/${space_id}/runtime`,
+ { headers }
+ );
+
+ if (res.status !== 200)
+ throw new Error("Space hardware could not be obtained.");
+
+ const { hardware } = await res.json();
+
+ return hardware;
+ } catch (e: any) {
+ throw new Error(e.message);
+ }
+}
+
+export async function set_space_hardware(
+ space_id: string,
+ new_hardware: (typeof hardware_types)[number],
+ token: `hf_${string}`
+) {
+ const headers: { Authorization?: string } = {};
+ if (token) {
+ headers.Authorization = `Bearer ${token}`;
+ }
+
+ try {
+ const res = await fetch(
+ `https://huggingface.co/api/spaces/${space_id}/hardware`,
+ { headers, body: JSON.stringify(new_hardware) }
+ );
+
+ if (res.status !== 200)
+ throw new Error(
+ "Space hardware could not be set. Please ensure the space hardware provided is valid and that a Hugging Face token is passed in."
+ );
+
+ const { hardware } = await res.json();
+
+ return hardware;
+ } catch (e: any) {
+ throw new Error(e.message);
+ }
+}
+
+export async function set_space_timeout(
+ space_id: string,
+ timeout: number,
+ token: `hf_${string}`
+) {
+ const headers: { Authorization?: string } = {};
+ if (token) {
+ headers.Authorization = `Bearer ${token}`;
+ }
+
+ try {
+ const res = await fetch(
+ `https://huggingface.co/api/spaces/${space_id}/hardware`,
+ { headers, body: JSON.stringify({ seconds: timeout }) }
+ );
+
+ if (res.status !== 200)
+ throw new Error(
+ "Space hardware could not be set. Please ensure the space hardware provided is valid and that a Hugging Face token is passed in."
+ );
+
+ const { hardware } = await res.json();
+
+ return hardware;
+ } catch (e: any) {
+ throw new Error(e.message);
+ }
+}
+
+export const hardware_types = [
+ "cpu-basic",
+ "cpu-upgrade",
+ "t4-small",
+ "t4-medium",
+ "a10g-small",
+ "a10g-large",
+ "a100-large",
+] as const;
diff --git a/src/index.ts b/src/index.ts
index a5f57f8..6da0ecf 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -6,3 +6,4 @@ export * from "./midjourne.api";
export * from "./command";
export * from "./verify.human";
export * from "./banned.words";
+export * from "./face.swap";
diff --git a/src/interfaces/message.ts b/src/interfaces/message.ts
index f4730de..c18638f 100644
--- a/src/interfaces/message.ts
+++ b/src/interfaces/message.ts
@@ -47,6 +47,14 @@ export interface MJSettings {
flags: number;
options: MJOptions[];
}
+export interface MJDescribe {
+ id: string;
+ flags: number;
+ uri: string;
+ proxy_url?: string;
+ options: MJOptions[];
+ descriptions: string[];
+}
export interface MJShorten {
description: string;
diff --git a/src/midjourne.api.ts b/src/midjourne.api.ts
index a6344bb..80082b6 100644
--- a/src/midjourne.api.ts
+++ b/src/midjourne.api.ts
@@ -359,7 +359,7 @@ export class MidjourneyApi extends Command {
return resp;
}
- async UploadImageByBole(blob: Blob, filename = "image.png") {
+ async UploadImageByBole(blob: Blob, filename = nextNonce() + ".png") {
const fileData = await blob.arrayBuffer();
const mimeType = blob.type;
const file_size = fileData.byteLength;
@@ -435,4 +435,30 @@ export class MidjourneyApi extends Command {
const payload = await this.describePayload(image, nonce);
return this.safeIteractions(payload);
}
+ async upImageApi(image: DiscordImage, nonce?: string) {
+ const { SalaiToken, DiscordBaseUrl, ChannelId, fetch } = this.config;
+ const payload = {
+ content: "",
+ nonce,
+ channel_id: ChannelId,
+ type: 0,
+ sticker_ids: [],
+ attachments: [image],
+ };
+
+ const url = new URL(
+ `${DiscordBaseUrl}/api/v9/channels/${ChannelId}/messages`
+ );
+ const headers = {
+ Authorization: SalaiToken,
+ "content-type": "application/json",
+ };
+ const response = await fetch(url, {
+ headers,
+ method: "POST",
+ body: JSON.stringify(payload),
+ });
+
+ return response.status;
+ }
}
diff --git a/src/midjourney.ts b/src/midjourney.ts
index 0c080ac..5e9414c 100644
--- a/src/midjourney.ts
+++ b/src/midjourney.ts
@@ -6,8 +6,15 @@ import {
} from "./interfaces";
import { MidjourneyApi } from "./midjourne.api";
import { MidjourneyMessage } from "./discord.message";
-import { toRemixCustom, custom2Type, nextNonce, random } from "./utls";
+import {
+ toRemixCustom,
+ custom2Type,
+ nextNonce,
+ random,
+ base64ToBlob,
+} from "./utls";
import { WsMessage } from "./discord.ws";
+import { faceSwap } from "./face.swap";
export class Midjourney extends MidjourneyMessage {
public config: MJConfig;
private wsClient?: WsMessage;
@@ -156,6 +163,7 @@ export class Midjourney extends MidjourneyMessage {
const wsClient = await this.getWsClient();
const nonce = nextNonce();
const DcImage = await this.MJApi.UploadImageByUri(imgUri);
+ this.log(`Describe`, DcImage);
const httpStatus = await this.MJApi.DescribeApi(DcImage, nonce);
if (httpStatus !== 204) {
throw new Error(`DescribeApi failed with status ${httpStatus}`);
@@ -361,6 +369,24 @@ export class Midjourney extends MidjourneyMessage {
});
}
+ async FaceSwap(target: string, source: string) {
+ const wsClient = await this.getWsClient();
+ const app = new faceSwap(this.config.HuggingFaceToken);
+ const Target = await (await this.config.fetch(target)).blob();
+ const Source = await (await this.config.fetch(source)).blob();
+ const res = await app.changeFace(Target, Source);
+ this.log(res[0]);
+ const blob = await base64ToBlob(res[0] as string);
+ const DcImage = await this.MJApi.UploadImageByBole(blob);
+ const nonce = nextNonce();
+ const httpStatus = await this.MJApi.DescribeApi(DcImage, nonce);
+ if (httpStatus !== 204) {
+ throw new Error(`DescribeApi failed with status ${httpStatus}`);
+ }
+ const describe = await wsClient.waitDescribe(nonce);
+ return describe?.uri;
+ }
+
Close() {
if (this.wsClient) {
this.wsClient.close();
diff --git a/src/utls/index.ts b/src/utls/index.ts
index c61f4ff..f912e4f 100644
--- a/src/utls/index.ts
+++ b/src/utls/index.ts
@@ -152,3 +152,24 @@ export const toRemixCustom = (customID: string) => {
const convertedString = `MJ::RemixModal::${parts[4]}::${parts[3]}::1`;
return convertedString;
};
+
+export async function base64ToBlob(base64Image: string): Promise {
+ // 移除 base64 图像头部信息
+ const base64Data = base64Image.replace(
+ /^data:image\/(png|jpeg|jpg);base64,/,
+ ""
+ );
+
+ // 将 base64 数据解码为二进制数据
+ const binaryData = atob(base64Data);
+
+ // 创建一个 Uint8Array 来存储二进制数据
+ const arrayBuffer = new ArrayBuffer(binaryData.length);
+ const uint8Array = new Uint8Array(arrayBuffer);
+ for (let i = 0; i < binaryData.length; i++) {
+ uint8Array[i] = binaryData.charCodeAt(i);
+ }
+
+ // 使用 Uint8Array 创建 Blob 对象
+ return new Blob([uint8Array], { type: "image/png" }); // 替换为相应的 MIME 类型
+}
diff --git a/test/face.ts b/test/face.ts
new file mode 100644
index 0000000..3e56998
--- /dev/null
+++ b/test/face.ts
@@ -0,0 +1,25 @@
+import "dotenv/config";
+import { faceSwap } from "../src";
+/**
+ *
+ * ```
+ * npx tsx test/face.ts
+ * ```
+ */
+
+async function test2() {
+ const app = new faceSwap(process.env.HuggingFaceToken);
+ const Target = await (
+ await fetch(
+ "https://cdn.discordapp.com/attachments/1108587422389899304/1129321837042602016/guapitu006_a_girls_face_with_david_bowies_thunderbolt_71ee5899-bd45-4fc4-8c9d-92f19ddb0a03.png"
+ )
+ ).blob();
+ const Source = await (
+ await fetch(
+ "https://cdn.discordapp.com/attachments/1108587422389899304/1129321826804306031/guapitu006_Cute_warrior_girl_in_the_style_of_Baten_Kaitos__111f39bc-329e-4fab-9af7-ee219fedf260.png"
+ )
+ ).blob();
+
+ await app.changeFace(Target, Source);
+}
+test2();
diff --git a/yarn.lock b/yarn.lock
index d3ba752..ed5462e 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -296,6 +296,11 @@ prettier@^2.8.8:
resolved "https://registry.npmjs.org/prettier/-/prettier-2.8.8.tgz"
integrity sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==
+semiver@^1.1.0:
+ version "1.1.0"
+ resolved "https://registry.yarnpkg.com/semiver/-/semiver-1.1.0.tgz#9c97fb02c21c7ce4fcf1b73e2c7a24324bdddd5f"
+ integrity sha512-QNI2ChmuioGC1/xjyYwyZYADILWyW6AmS1UH6gDj/SFUUUS4MBAWs/7mxnkRPc/F4iHezDP+O8t0dO8WHiEOdg==
+
snowyflake@^2.0.0:
version "2.0.0"
resolved "https://registry.npmjs.org/snowyflake/-/snowyflake-2.0.0.tgz"