Skip to content

Commit

Permalink
generation requests return images
Browse files Browse the repository at this point in the history
  • Loading branch information
KAJdev committed Jul 20, 2023
1 parent 4d435a4 commit e7089b8
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 214 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/comfy_windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ on:
branches:
- tauri

concurrency:
group: comfyui_windows

jobs:
repackage_comfyui:
permissions:
Expand Down
1 change: 0 additions & 1 deletion packages/stablestudio-ui/src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]

use std::collections::HashMap;
use std::fmt::format;
use std::fs::File;
use std::sync::OnceLock;
use tauri::api::process::CommandEvent;
Expand Down
5 changes: 5 additions & 0 deletions packages/stablestudio-ui/src-tauri/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ impl Builder {
"`ws${window.location.protocol === \"https:\" ? \"s\" : \"\"}://${location.host}/ws${existingSession}`",
"`ws://localhost:5000/ws${existingSession}`"
);

// add some stuff to app.js
if path_name.ends_with("app.js") && !file_contents.ends_with("app.api = api;") {
file_contents = file_contents + "\napp.api = api;";
}

let response = HttpResponse::from_data(file_contents.as_bytes()).with_status_code(200).with_header(
Header::from_bytes("Content-Type", mimetype.unwrap().as_str()).unwrap(),
Expand Down
1 change: 1 addition & 0 deletions packages/stablestudio-ui/src/App/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ export namespace App {

setRunning(true);
setIsSetup(SetupState.ComfyRunning);
Comfy.registerListeners();
}, [isSetup, print, setRunning, setUnlisteners]);

useEffect(() => {
Expand Down
173 changes: 169 additions & 4 deletions packages/stablestudio-ui/src/Comfy/index.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import * as StableStudio from "@stability/stablestudio-plugin";
import { useLocation } from "react-router-dom";
import { create } from "zustand";
import { Generation } from "~/Generation";

export type Comfy = {
export type ComfyApp = {
setup: () => void;
registerNodes: () => void;
loadGraphData: (graph: Graph) => void;
Expand All @@ -13,6 +15,21 @@ export type Comfy = {
refreshComboInNodes: () => Promise<void>;
queuePrompt: (number: number, batchCount: number) => Promise<void>;
clean: () => void;
api: ComfyAPI;
};

export type ComfyAPI = {
addEventListener: (event: string, callback: (detail: any) => void) => void;
};

export type Comfy = { app: ComfyApp; api: ComfyAPI };

export type ComfyOutput = {
images: {
filename: string;
subfolder: string;
type: string;
}[];
};

export type Graph = {
Expand Down Expand Up @@ -91,11 +108,11 @@ type State = {
};

export namespace Comfy {
export const get = (): Comfy | null =>
export const get = (): ComfyApp | null =>
((
(document.getElementById("comfyui-window") as HTMLIFrameElement)
?.contentWindow as Window & { app: Comfy }
)?.app as Comfy) ?? null;
?.contentWindow as Window & { app: ComfyApp }
)?.app as ComfyApp) ?? null;

export const use = create<State>((set) => ({
output: [],
Expand All @@ -115,4 +132,152 @@ export namespace Comfy {
unlisteners: [],
setUnlisteners: (unlisteners) => set({ unlisteners }),
}));

export const registerListeners = async () => {
let api = get()?.api;

while (!api) {
await new Promise((resolve) => setTimeout(resolve, 1000));
api = get()?.api;
}

api.addEventListener("executed", async ({ detail }) => {
const { output, prompt_id } = detail;

console.log("executed_in_comfy_domain", detail);

const newInputs: Record<ID, Generation.Image.Input> = {};
const responses: Generation.Images = [];

const input = Generation.Image.Input.get(prompt_id);

const images = await Promise.all(
(output as ComfyOutput).images.map(async (image) => {
console.log("image", image);
const resp = await fetch(
`http://localhost:3000/view?filename=${image.filename}&subfolder=${
image.subfolder || ""
}&type=${image.type}`,
{
cache: "no-cache",
}
);

const blob = await resp.blob();
const url = URL.createObjectURL(blob);
console.log("url", url);

const output = Generation.Image.Output.get(prompt_id);

return {
id: ID.create(),
blob,
inputID: output?.inputID ?? "",
createdAt: new Date(),
};
})
);

for (const image of images) {
const inputID = ID.create();
const newInput = {
...Generation.Image.Input.initial(inputID),
...input,
seed: 0,
id: inputID,
};

const cropped = await cropImage(image, newInput);
if (!cropped) continue;

responses.push(cropped);
newInputs[inputID] = newInput;
}

Generation.Image.Inputs.set({
...Generation.Image.Inputs.get(),
...newInputs,
});
responses.forEach(Generation.Image.add);
Generation.Image.Output.received(prompt_id, responses);
});

api.addEventListener("execution_start", ({ detail }) => {
const { prompt_id } = detail;

console.log("execution_start", detail);

if (prompt_id) {
let input = Generation.Image.Input.get(prompt_id);
if (!input) {
input = Generation.Image.Input.initial(prompt_id);
Generation.Image.Inputs.set((inputs) => ({
...inputs,
[prompt_id]: input,
}));
}
const output = Generation.Image.Output.requested(
prompt_id,
{},
prompt_id
);
Generation.Image.Output.set(output);
}
});

api.addEventListener("execution_error", ({ detail }) => {
console.log("execution_error", detail);
Generation.Image.Output.clear(detail.prompt_id);
});

console.log("registered ComfyUI listeners");
};
}

function cropImage(
image: StableStudio.StableDiffusionImage,
input: Generation.Image.Input
) {
return new Promise<Generation.Image | void>((resolve) => {
const id = image.id;
const blob = image.blob;
if (!blob || !id) return resolve();

// crop image to box size
const croppedCanvas = document.createElement("canvas");
croppedCanvas.width = input.width;
croppedCanvas.height = input.height;

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const croppedCtx = croppedCanvas.getContext("2d")!;

const img = new window.Image();
img.src = URL.createObjectURL(blob);
img.onload = () => {
croppedCtx.drawImage(
img,
0,
0,
input.width,
input.height,
0,
0,
input.width,
input.height
);

croppedCanvas.toBlob((blob) => {
if (blob) {
const objectURL = URL.createObjectURL(blob);
resolve({
id,
inputID: input.id,
created: new Date(),
src: objectURL,
finishReason: 0,
});
}
});
};
});
}
Loading

0 comments on commit e7089b8

Please sign in to comment.