Update the Imaginate image generation API (#836)

* Update the Imaginate feature server API

* Change connection status strategy to prevent console errors

* Possible CORS fix? Maybe revert.

* Update to the final API and fix bugs
This commit is contained in:
Keavon Chambers 2022-11-02 17:20:29 -07:00
parent 9d56e86203
commit 5be59f7fce
3 changed files with 121 additions and 227 deletions

View File

@ -615,7 +615,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi
}))],
ImaginateStatus::Generating => vec![WidgetHolder::new(Widget::TextButton(TextButton {
label: "Terminate".into(),
tooltip: "Cancel in-progress image generation and keep the latest progress".into(),
tooltip: "Cancel the in-progress image generation and keep the latest progress".into(),
on_update: WidgetCallback::new(|_| DocumentMessage::ImaginateTerminate.into()),
..Default::default()
}))],
@ -745,7 +745,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi
},
LayoutGroup::Row {
widgets: {
let tooltip = "Number of iterations to improve the image generation quality, with diminishing returns around 40".to_string();
let tooltip = "Number of iterations to improve the image generation quality, with diminishing returns around 40 when using the Euler A sampling method".to_string();
vec![
WidgetHolder::new(Widget::TextLabel(TextLabel {
value: "Sampling Steps".into(),
@ -774,13 +774,7 @@ fn node_section_imaginate(imaginate_layer: &ImaginateLayer, layer: &Layer, persi
},
LayoutGroup::Row {
widgets: {
let tooltip = "
Algorithm used to generate the image during each sampling step.\n\
\n\
'DPM Fast' and 'DPM Adaptive' do not support live refreshing updates.
"
.trim()
.to_string();
let tooltip = "Algorithm used to generate the image during each sampling step".to_string();
let sampling_methods = ImaginateSamplingMethod::list();
let mut entries = Vec::with_capacity(sampling_methods.len());

View File

@ -584,12 +584,14 @@ export default defineComponent({
},
async activateEyedropperSample() {
// TODO: Replace this temporary solution that only works in Chromium-based browsers with the custom color sampler used by the Eyedropper tool
// eslint-disable-next-line @typescript-eslint/no-explicit-any
if (!(window as any).EyeDropper) {
this.editor.instance.eyedropperSampleForColorPicker();
return;
}
try {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const result = await new (window as any).EyeDropper().open();
this.setColorCode(result.sRGBHex);
} catch {

View File

@ -1,14 +1,15 @@
import { escapeJSON } from "@/utility-functions/escape";
/* eslint-disable camelcase */
// import { escapeJSON } from "@/utility-functions/escape";
import { blobToBase64 } from "@/utility-functions/files";
import { type RequestResult, requestWithUploadDownloadProgress } from "@/utility-functions/network";
import { stripIndents } from "@/utility-functions/strip-indents";
import { type Editor } from "@/wasm-communication/editor";
import type { XY } from "@/wasm-communication/messages";
import { type ImaginateGenerationParameters } from "@/wasm-communication/messages";
const MAX_POLLING_RETRIES = 4;
const SERVER_STATUS_CHECK_TIMEOUT = 5000;
const SAMPLING_MODES_POLLING_UNSUPPORTED = ["DPM fast", "DPM adaptive"];
const PROGRESS_EVERY_N_STEPS = 5;
let timer: NodeJS.Timeout | undefined;
let terminated = false;
@ -50,7 +51,7 @@ export async function imaginateGenerate(
// Begin polling for updates to the in-progress image generation at the specified interval
// Don't poll if the chosen interval is 0, or if the chosen sampling method does not support polling
if (refreshFrequency > 0 && !SAMPLING_MODES_POLLING_UNSUPPORTED.includes(parameters.samplingMethod)) {
if (refreshFrequency > 0) {
const interval = Math.max(refreshFrequency * 1000, 500);
scheduleNextPollingUpdate(interval, Date.now(), 0, editor, hostname, documentId, layerPath, parameters.resolution);
}
@ -62,9 +63,9 @@ export async function imaginateGenerate(
}
// Extract the final image from the response and convert it to a data blob
// Highly unstable API
const base64 = JSON.parse(body)?.data[0]?.[0] as string | undefined;
if (typeof base64 !== "string" || !base64.startsWith("data:image/png;base64,")) throw new Error("Could not read final image result from server response");
const base64Data = JSON.parse(body)?.images?.[0] as string | undefined;
const base64 = typeof base64Data === "string" && base64Data.length > 0 ? `data:image/png;base64,${base64Data}` : undefined;
if (!base64) throw new Error("Could not read final image result from server response");
const blob = await (await fetch(base64)).blob();
// Send the backend an updated status
@ -156,9 +157,11 @@ function scheduleNextPollingUpdate(
try {
const [blob, percentComplete] = await pollImage(hostname);
// After waiting for the polling result back from the server, if during that intervening time the user has terminated the generation, exit so we don't overwrite that terminated status
if (terminated) return;
preloadAndSetImaginateBlobURL(editor, blob, documentId, layerPath, resolution.x, resolution.y);
if (blob) preloadAndSetImaginateBlobURL(editor, blob, documentId, layerPath, resolution.x, resolution.y);
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, percentComplete, "Generating");
scheduleNextPollingUpdate(interval, nextTimeoutBegan, 0, editor, hostname, documentId, layerPath, resolution);
@ -178,46 +181,30 @@ function scheduleNextPollingUpdate(
}
// API COMMUNICATION FUNCTIONS
// These are highly unstable APIs that will need to be updated very frequently, so we currently assume usage of this exact commit from the server:
// https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/7d6042b908c064774ee10961309d396eabdc6c4a
function endpoint(hostname: string): string {
// Highly unstable API
return `${hostname}api/predict/`;
}
async function pollImage(hostname: string): Promise<[Blob | undefined, number]> {
// Fetch the percent progress and in-progress image from the API
const result = await fetch(`${hostname}sdapi/v1/progress`, { signal: pollingAbortController.signal, method: "GET" });
const { current_image, progress } = await result.json();
async function pollImage(hostname: string): Promise<[Blob, number]> {
// Highly unstable API
const result = await fetch(endpoint(hostname), {
signal: pollingAbortController.signal,
headers: {
accept: "*/*",
"accept-language": "en-US,en;q=0.9",
"content-type": "application/json",
},
referrer: hostname,
referrerPolicy: "strict-origin-when-cross-origin",
body: stripIndents`
{
"fn_index":3,
"data":[],
"session_hash":"0000000000"
}`,
method: "POST",
mode: "cors",
credentials: "omit",
});
const json = await result.json();
// Highly unstable API
const percentComplete = Math.abs(Number(json.data[0].match(/(?<="width:).*?(?=%")/)[0])); // The API sometimes returns negative values presumably due to a bug
// Highly unstable API
const base64 = json.data[2];
// Convert to a usable format
const progressPercent = progress * 100;
const base64 = typeof current_image === "string" && current_image.length > 0 ? `data:image/png;base64,${current_image}` : undefined;
if (typeof base64 !== "string" || !base64.startsWith("data:image/png;base64,")) return Promise.reject();
// Deal with a missing image
if (!base64) {
// The image is not ready yet (because it's only had a few samples since generation began), but we do have a progress percentage
if (!Number.isNaN(progressPercent) && progressPercent >= 0 && progressPercent <= 100) {
return [undefined, progressPercent];
}
// Something else is wrong and the image wasn't provided as expected
return Promise.reject();
}
// The image was provided so we turn it into a data blob
const blob = await (await fetch(base64)).blob();
return [blob, percentComplete];
return [blob, progressPercent];
}
async function generate(
@ -231,138 +218,83 @@ async function generate(
xhr?: XMLHttpRequest;
}> {
let body;
let endpoint;
if (image === undefined || parameters.denoisingStrength === undefined) {
// Highly unstable API
body = stripIndents`
{
"fn_index":13,
"data":[
"${escapeJSON(parameters.prompt)}",
"${escapeJSON(parameters.negativePrompt)}",
"None",
"None",
${parameters.samples},
"${parameters.samplingMethod}",
${parameters.restoreFaces},
${parameters.tiling},
1,
1,
${parameters.cfgScale},
${parameters.seed},
-1,
0,
0,
0,
false,
${parameters.resolution.y},
${parameters.resolution.x},
false,
0.7,
0,
0,
"None",
false,
false,
null,
"",
"Seed",
"",
"Nothing",
"",
true,
false,
false,
null,
""
],
"session_hash":"0000000000"
}`;
endpoint = `${hostname}sdapi/v1/txt2img`;
body = {
// enable_hr: false,
// denoising_strength: 0,
// firstphase_width: 0,
// firstphase_height: 0,
prompt: parameters.prompt,
// styles: [],
seed: Number(parameters.seed),
// subseed: -1,
// subseed_strength: 0,
// seed_resize_from_h: -1,
// seed_resize_from_w: -1,
// batch_size: 1,
// n_iter: 1,
steps: parameters.samples,
cfg_scale: parameters.cfgScale,
width: parameters.resolution.x,
height: parameters.resolution.y,
restore_faces: parameters.restoreFaces,
tiling: parameters.tiling,
negative_prompt: parameters.negativePrompt,
// eta: 0,
// s_churn: 0,
// s_tmax: 0,
// s_tmin: 0,
// s_noise: 1,
override_settings: {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
},
sampler_index: parameters.samplingMethod,
};
} else {
const sourceImageBase64 = await blobToBase64(image);
// Highly unstable API
body = stripIndents`
{
"fn_index":33,
"data":[
0,
"${escapeJSON(parameters.prompt)}",
"${escapeJSON(parameters.negativePrompt)}",
"None",
"None",
"${sourceImageBase64}",
null,
null,
null,
"Draw mask",
${parameters.samples},
"${parameters.samplingMethod}",
4,
"fill",
${parameters.restoreFaces},
${parameters.tiling},
1,
1,
${parameters.cfgScale},
${parameters.denoisingStrength},
${parameters.seed},
-1,
0,
0,
0,
false,
${parameters.resolution.y},
${parameters.resolution.x},
"Just resize",
false,
32,
"Inpaint masked",
"",
"",
"None",
"",
true,
true,
"",
"",
true,
50,
true,
1,
0,
false,
4,
1,
"",
128,
8,
["left","right","up","down"],
1,
0.05,
128,
4,
"fill",
["left","right","up","down"],
false,
false,
null,
"",
"",
64,
"None",
"Seed",
"",
"Nothing",
"",
true,
false,
false,
null,
"",
""
],
"session_hash":"0000000000"
}`;
endpoint = `${hostname}sdapi/v1/img2img`;
body = {
init_images: [sourceImageBase64],
// resize_mode: 0,
denoising_strength: parameters.denoisingStrength,
// mask: "",
// mask_blur: 4,
// inpainting_fill: 0,
// inpaint_full_res: true,
// inpaint_full_res_padding: 0,
// inpainting_mask_invert: 0,
prompt: parameters.prompt,
// styles: [],
seed: Number(parameters.seed),
// subseed: -1,
// subseed_strength: 0,
// seed_resize_from_h: -1,
// seed_resize_from_w: -1,
// batch_size: 1,
// n_iter: 1,
steps: parameters.samples,
cfg_scale: parameters.cfgScale,
width: parameters.resolution.x,
height: parameters.resolution.y,
restore_faces: parameters.restoreFaces,
tiling: parameters.tiling,
negative_prompt: parameters.negativePrompt,
// eta: 0,
// s_churn: 0,
// s_tmax: 0,
// s_tmin: 0,
// s_noise: 1,
override_settings: {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
},
sampler_index: parameters.samplingMethod,
// include_init_images: false,
};
}
// Prepare a promise that will resolve after the outbound request upload is complete
@ -381,7 +313,7 @@ async function generate(
uploadedResolve();
}
};
const [result, xhr] = requestWithUploadDownloadProgress(endpoint(hostname), "POST", body, uploadProgress, abortAndResetPolling);
const [result, xhr] = requestWithUploadDownloadProgress(endpoint, "POST", JSON.stringify(body), uploadProgress, abortAndResetPolling);
result.catch(() => uploadedReject());
// Return the promise that resolves when the request upload is complete, the promise that resolves when the response download is complete, and the XHR so it can be aborted
@ -389,26 +321,7 @@ async function generate(
}
async function terminate(hostname: string): Promise<void> {
const body = stripIndents`
{
"fn_index":2,
"data":[],
"session_hash":"0000000000"
}`;
await fetch(endpoint(hostname), {
headers: {
accept: "*/*",
"accept-language": "en-US,en;q=0.9",
"content-type": "application/json",
},
referrer: hostname,
referrerPolicy: "strict-origin-when-cross-origin",
body,
method: "POST",
mode: "cors",
credentials: "omit",
});
await fetch(`${hostname}sdapi/v1/interrupt`, { method: "POST" });
}
async function checkConnection(hostname: string): Promise<boolean> {
@ -417,33 +330,18 @@ async function checkConnection(hostname: string): Promise<boolean> {
const timeout = setTimeout(() => statusAbortController.abort(), SERVER_STATUS_CHECK_TIMEOUT);
const body = stripIndents`
{
"fn_index":100,
"data":[],
"session_hash":"0000000000"
}`;
try {
await fetch(endpoint(hostname), {
signal: statusAbortController.signal,
headers: {
accept: "*/*",
"accept-language": "en-US,en;q=0.9",
"content-type": "application/json",
},
referrer: hostname,
referrerPolicy: "strict-origin-when-cross-origin",
body,
method: "POST",
mode: "cors",
credentials: "omit",
});
// Intentionally misuse this API endpoint by using it just to check for a code 200 response, regardless of what the result is
const { status } = await fetch(`${hostname}sdapi/v1/progress?skip_current_image=true`, { signal: statusAbortController.signal, method: "GET" });
clearTimeout(timeout);
return true;
} catch (_) {
return false;
// This code means the server has indeed responded and the endpoint exists (otherwise it would be 404)
if (status === 200) {
clearTimeout(timeout);
return true;
}
} catch {
// Do nothing here
}
return false;
}